
TabTransformer++ for Residual Learning


A novel extension of TabTransformer with gated fusion for residual-based model stacking

Quick Start · Architecture · Production · Changelog


Overview

This project implements TabTransformer++, an enhanced transformer architecture designed specifically for tabular data in a residual learning framework. Rather than predicting targets directly, the model learns to correct errors from simpler base models—a powerful technique for competition-winning ensembles.

The Residual Learning Approach

+--------------------+     +------------------------+     +-----------------------+
|    Base Model      |     |    TabTransformer++    |     |   Final Prediction    |
| (HistGBR, XGBoost) | --> |   Predicts Residual    | --> |   Base + Residual     |
|    -> base_pred    |     |        (error)         |     |                       |
+--------------------+     +------------------------+     +-----------------------+

Why residual learning?

  • Base models capture linear/tree patterns efficiently
  • Transformers excel at learning complex feature interactions
  • Combined: each model focuses on what it does best

Novel Architectural Contributions

TabTransformer++ introduces six key innovations over the original TabTransformer:

1. Dual Representation (Tokens + Scalars)

Each feature is represented in two complementary ways:

| Type            | Creation                       | Captures                                 |
|-----------------|--------------------------------|------------------------------------------|
| Token Embedding | Quantile bin -> learned vector | Discrete patterns, ordinal relationships |
| Value Embedding | Raw scalar -> MLP projection   | Precise numeric magnitude                |

Why both? Binning loses precision (1.01 and 1.99 may share a bin), but raw scalars lack pattern-matching power.

2. Learnable Gated Fusion (Safe Initialization)

Per-feature gates control the blend between token and scalar representations:

final_emb[i] = token_emb[i] + sigmoid(gate[i]) * value_emb[i]

Safe Initialization: Gates are initialized to -2.0 (sigmoid ≈ 0.12), biasing the model to rely on stable token embeddings first. This prevents early divergence before the model learns when to trust scalar values.

  • Gates are learned independently for each feature
  • Model adapts to each column's characteristics automatically
  • Low gate → token-dominant (categorical treatment)
  • High gate → scalar-dominant (precise numeric treatment)
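
A minimal PyTorch sketch of this mechanism (the class and attribute names here are illustrative, not the exact ones in model.py):

import torch
import torch.nn as nn

class GatedFusion(nn.Module):
    def __init__(self, n_features: int, gate_init: float = -2.0):
        super().__init__()
        # One learnable gate logit per feature; -2.0 gives sigmoid(gate) ≈ 0.12
        self.gate = nn.Parameter(torch.full((n_features,), gate_init))

    def forward(self, token_emb, value_emb):
        # token_emb, value_emb: (batch, n_features, embed_dim)
        g = torch.sigmoid(self.gate).view(1, -1, 1)  # broadcast over batch and embed dims
        return token_emb + g * value_emb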

3. Per-Token Value MLPs

Each feature gets its own projection network instead of a single shared one:

Linear(1 -> 64) -> GELU -> Linear(64 -> 64) -> LayerNorm

This allows a different transformation for each feature's distribution.
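
A sketch of how such per-feature projections can be built with nn.ModuleList; the sizes (8 features, batch of 32) are arbitrary examples:

import torch
import torch.nn as nn

n_features, embed_dim = 8, 64  # example sizes

# One independent Linear(1->64) -> GELU -> Linear(64->64) -> LayerNorm per feature
value_mlps = nn.ModuleList(
    nn.Sequential(
        nn.Linear(1, embed_dim),
        nn.GELU(),
        nn.Linear(embed_dim, embed_dim),
        nn.LayerNorm(embed_dim),
    )
    for _ in range(n_features)
)

values = torch.randn(32, n_features)  # raw scalars for a batch of 32
value_emb = torch.stack(
    [mlp(values[:, i : i + 1]) for i, mlp in enumerate(value_mlps)], dim=1
)  # -> (32, n_features, 64)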

4. TokenDrop Regularization (with Inverted Scaling)

During training, randomly zero out feature embeddings (p=0.12):

# x: (batch, seq_len, dim) embeddings; p = 0.12
mask = (torch.rand(x.shape[0], x.shape[1], 1, device=x.device) > p).float()
mask[:, 0] = 1.0        # never drop the [CLS] token
x = x * mask / (1 - p)  # inverted scaling for magnitude consistency

Prevents over-reliance on any single feature. The inverted scaling maintains expected magnitude between train and test modes (like standard Dropout).

5. CLS Token Aggregation

BERT-style [CLS] token prepended to the sequence:

[CLS, feat_1, feat_2, ..., feat_n, base_pred, dt_pred]

CLS attends to all features and produces the final representation.
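
A minimal sketch of the prepend step (shapes are illustrative):

import torch
import torch.nn as nn

batch, n_tokens, embed_dim = 32, 10, 64
cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))  # learned [CLS] embedding

x = torch.randn(batch, n_tokens, embed_dim)                 # fused feature embeddings
x = torch.cat([cls_token.expand(batch, -1, -1), x], dim=1)  # prepend [CLS]
# After the encoder, x[:, 0] is the pooled row representation.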

6. Pre-LayerNorm Transformer

Uses norm_first=True for more stable training without warmup:

Pre-LN:  x = x + Attention(LayerNorm(x))   [Stable]
Post-LN: x = LayerNorm(x + Attention(x))   [Requires warmup]
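
In PyTorch this is a one-flag change. A sketch matching the encoder dimensions in the diagram below (64-dim, 4 heads, FFN 256, 3 layers):

import torch.nn as nn

layer = nn.TransformerEncoderLayer(
    d_model=64, nhead=4, dim_feedforward=256,
    batch_first=True,
    norm_first=True,  # Pre-LN: normalize before each attention/FFN sublayer
)
encoder = nn.TransformerEncoder(layer, num_layers=3)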

Architecture Diagram

                     +-------------------------------------+
                     |     INPUT: T features + 2 meta      |
                     |   (tokens, raw_values) per feature  |
                     +-------------------------------------+
                                       |
          +----------------------------+----------------------------+
          |                            |                            |
          v                            v                            v
   +-------------+              +-------------+              +-------------+
   | Feature 1   |              | Feature 2   |     ...      | Feature T   |
   | token->embed|              | token->embed|              | token->embed|
   | value->MLP  |              | value->MLP  |              | value->MLP  |
   | gate fusion |              | gate fusion |              | gate fusion |
   +-------------+              +-------------+              +-------------+
          |                            |                            |
          +----------------------------+----------------------------+
                                       |
                                       v
                            +----------------------+
                            |  Embedding Dropout   |
                            |      (p=0.05)        |
                            +----------------------+
                                       |
                                       v
                            +----------------------+
                            |   Prepend [CLS]      |
                            |      Token           |
                            +----------------------+
                                       |
                                       v
                            +----------------------+
                            |    TokenDrop         |
                            |  (p=0.12, train)     |
                            +----------------------+
                                       |
                                       v
                     +-------------------------------------+
                     |      TRANSFORMER ENCODER            |
                     |  +-------------------------------+  |
                     |  | Layer 1: 4-head attention     |  |
                     |  | + FFN(64->256->64) + PreLN    |  |
                     |  +-------------------------------+  |
                     |  +-------------------------------+  |
                     |  | Layer 2: 4-head attention     |  |
                     |  | + FFN(64->256->64) + PreLN    |  |
                     |  +-------------------------------+  |
                     |  +-------------------------------+  |
                     |  | Layer 3: 4-head attention     |  |
                     |  | + FFN(64->256->64) + PreLN    |  |
                     |  +-------------------------------+  |
                     +-------------------------------------+
                                       |
                                       v
                            +----------------------+
                            |  Extract [CLS]       |
                            |    Embedding         |
                            +----------------------+
                                       |
                                       v
                     +--------------------------------------+
                     |           PREDICTION HEAD            |
                     |  LayerNorm -> Linear(64->192)        |
                     |  -> GELU -> Dropout -> Linear(192->1)|
                     +--------------------------------------+
                                       |
                                       v
                          +---------------------+
                          | Predicted Residual  |
                          |  (robust-scaled)    |
                          +---------------------+

Interpretability Features

TabTransformer++ includes built-in interpretability tools:

Gate Value Visualization

Extract and visualize learned gate values to understand feature treatment:

gate_values = extract_gate_values(model, feature_names)
visualize_gate_values(gate_values)

  • Low gate (near 0): Feature works better as categorical bins
  • High gate (near 1): Feature requires precise scalar values
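
Conceptually, extraction reduces to reading the learned gate logits through a sigmoid. A sketch assuming the model stores one gate logit per feature (as in the GatedFusion sketch above); model and feature_names come from your training run:

import torch

with torch.no_grad():
    gates = torch.sigmoid(model.gate).cpu().numpy()  # one value in (0, 1) per feature
for name, g in zip(feature_names, gates):
    print(f"{name:20s} gate = {g:.3f}")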

Token Embedding Visualization

Visualize learned embeddings using t-SNE or PCA:

visualize_token_embeddings(model, tokenizer, feature_idx=0, method='pca')

Shows how the model organizes quantile bins in embedding space, revealing learned semantic relationships.


Why TabTransformer++ Over XGBoost?

Even when RMSE is comparable, TabTransformer++ offers unique advantages:

| Capability          | XGBoost                            | TabTransformer++                                  |
|---------------------|------------------------------------|---------------------------------------------------|
| Dense Embeddings    | ❌ No                              | ✅ Each row becomes a learned vector              |
| Multi-Modal Fusion  | ❌ Cannot combine with images/text | ✅ Embeddings fuse with vision/NLP models         |
| Transfer Learning   | ❌ Must retrain from scratch       | ✅ Pre-train on large tables, fine-tune on small  |
| Interpretable Gates | ❌ Feature importance only         | ✅ Learns token vs scalar preference per feature  |
| GPU Batch Inference | ⚠️ Limited                         | ✅ Native PyTorch batching                        |

The Real Value: TabTransformer++ generates dense embeddings suitable for downstream multi-modal tasks (e.g., combining tabular property data with house images).


Training Pipeline

The notebook implements a complete 5-fold cross-validation pipeline:

Step 1: Base Model Stacking

from sklearn.ensemble import HistGradientBoostingRegressor, RandomForestRegressor
from sklearn.model_selection import cross_val_predict

# HistGradientBoostingRegressor for base predictions (captures non-linearity)
model_base = HistGradientBoostingRegressor(max_iter=100, max_depth=5)

# RandomForest for additional signal
model_dt = RandomForestRegressor(n_estimators=20, max_depth=8)

# Out-of-fold predictions to prevent leakage (X, target = training features/labels)
base_pred = cross_val_predict(model_base, X, target, cv=5)
residual = target - base_pred

Why HistGradientBoostingRegressor instead of Ridge?

  • Captures non-linear patterns that linear models miss
  • Leaves purer high-order feature interactions for the Transformer
  • Faster than RandomForest due to histogram-based splits

Step 2: Tabular Tokenization

  • Quantile binning: 32 bins for features, 128 for base_pred, 64 for tree_pred
  • Robust scaling: (x - median) / IQR — resistant to outliers (replaces Z-score)
  • Fit on training fold only (leak-free)

Why Robust Scaling? Z-score (x - mean) / std is sensitive to outliers, which can cause gradient explosions in the scalar path. Robust scaling using median and IQR stabilizes training across all folds.
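
A minimal sketch of both transforms, with fit statistics taken from the training fold only (the real implementation lives in tokenizer.py):

import numpy as np

def fit_stats(x_train, n_bins=32):
    # Interior quantile edges plus robust-scaling statistics, train fold only
    edges = np.quantile(x_train, np.linspace(0, 1, n_bins + 1)[1:-1])
    median = np.median(x_train)
    iqr = np.subtract(*np.percentile(x_train, [75, 25]))
    return np.unique(edges), median, iqr

def transform(x, edges, median, iqr):
    tokens = np.digitize(x, edges)         # quantile-bin token ids
    values = (x - median) / (iqr + 1e-8)   # robust scaling: (x - median) / IQR
    return tokens, values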

Step 3: Train TabTransformer++

  • EMA (Polyak averaging): Maintains exponential moving average of weights
  • Huber loss: Robust to outliers
  • AdamW optimizer: With weight decay regularization
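
A compact sketch combining the three ingredients (the Trainer class handles this internally; loader and the EMA decay of 0.999 are assumed values):

import torch

opt = torch.optim.AdamW(model.parameters(), lr=2e-3, weight_decay=0.01)
loss_fn = torch.nn.HuberLoss()  # robust to residual outliers

# Polyak/EMA shadow copy of the weights
ema = {n: p.detach().clone() for n, p in model.named_parameters()}
decay = 0.999

for tok, val, y in loader:
    opt.zero_grad()
    loss = loss_fn(model(tok, val).squeeze(-1), y)
    loss.backward()
    opt.step()
    with torch.no_grad():
        for n, p in model.named_parameters():
            ema[n].mul_(decay).add_(p.detach(), alpha=1 - decay)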

Step 4: Isotonic Calibration

Post-training calibration maps z-scored predictions to actual residuals:

from sklearn.isotonic import IsotonicRegression

iso = IsotonicRegression(out_of_bounds="clip")
iso.fit(preds_z, y_va_raw)
calibrated = iso.predict(preds_z)

Step 5: Final Ensemble

final_prediction = base_pred + calibrated_residual

System Design: Production Deployment

This section outlines how TabTransformer++ fits into a production ML system.

Architecture Overview

┌─────────────────────────────────────────────────────────────────────────────┐
│                           TRAINING PIPELINE                                  │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                              │
│   ┌──────────────┐    ┌───────────────────┐    ┌─────────────────────────┐  │
│   │  Raw Data    │───▶│  TabularTokenizer │───▶│  Feature Store          │  │
│   │  (Offline)   │    │  .fit() on TRAIN  │    │  (Serialize tokenizer)  │  │
│   └──────────────┘    └───────────────────┘    └─────────────────────────┘  │
│                              │                                               │
│                              ▼                                               │
│                     ┌─────────────────────┐                                  │
│                     │  TabTransformer++   │                                  │
│                     │  PyTorch Training   │                                  │
│                     └─────────────────────┘                                  │
│                              │                                               │
│                              ▼                                               │
│   ┌─────────────────────────────────────────────────────────────────────┐   │
│   │                     Model Export                                     │   │
│   ├─────────────────────────────────────────────────────────────────────┤   │
│   │  • torch.jit.script() → TorchScript (.pt)                           │   │
│   │  • torch.onnx.export() → ONNX (.onnx)                               │   │
│   │  • TensorRT optimization for NVIDIA GPUs                            │   │
│   └─────────────────────────────────────────────────────────────────────┘   │
│                                                                              │
└─────────────────────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────────────────────┐
│                          INFERENCE PIPELINE                                  │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                              │
│   ┌──────────────┐    ┌───────────────────┐    ┌─────────────────────────┐  │
│   │  New Request │───▶│  Feature Store    │───▶│  Tokenizer.transform()  │  │
│   │  (Online)    │    │  (Load tokenizer) │    │  (Consistent binning)   │  │
│   └──────────────┘    └───────────────────┘    └─────────────────────────┘  │
│                                                          │                   │
│                                                          ▼                   │
│                              ┌────────────────────────────────────────────┐  │
│                              │  Inference Runtime                         │  │
│                              ├────────────────────────────────────────────┤  │
│                              │  • ONNX Runtime (CPU/GPU)                  │  │
│                              │  • TensorRT (NVIDIA, <1ms latency)         │  │
│                              │  • TorchServe / Triton Inference Server    │  │
│                              └────────────────────────────────────────────┘  │
│                                                          │                   │
│                                                          ▼                   │
│                              ┌─────────────────────────────────────────┐     │
│                              │  Prediction + Post-Processing           │     │
│                              │  base_pred + calibrated_residual        │     │
│                              └─────────────────────────────────────────┘     │
│                                                                              │
└─────────────────────────────────────────────────────────────────────────────┘

Key Production Considerations

1. Tokenizer Serialization to Feature Store

The TabularTokenizer encapsulates learned quantile bins and scaling statistics. For online/offline consistency:

import pickle

# After training
with open("tokenizer.pkl", "wb") as f:
    pickle.dump(tokenizer, f)

# Upload to Feature Store (e.g., Feast, Tecton, SageMaker Feature Store)
feature_store.register_artifact("tabtransformer_tokenizer", "tokenizer.pkl")

Why Feature Store?

  • Ensures identical preprocessing in training and serving
  • Version control for tokenizer artifacts
  • Supports A/B testing different tokenizer configurations

2. Model Export for Low-Latency Inference

# Export to ONNX (cross-platform, optimized inference)
import torch.onnx

model.eval()
dummy_tok = torch.randint(0, 32, (1, num_features))
dummy_val = torch.randn(1, num_features)

torch.onnx.export(
    model,
    (dummy_tok, dummy_val),
    "tabtransformer.onnx",
    input_names=["tokens", "values"],
    output_names=["prediction"],
    dynamic_axes={"tokens": {0: "batch"}, "values": {0: "batch"}},
)

# For NVIDIA GPUs: Convert to TensorRT
# trtexec --onnx=tabtransformer.onnx --saveEngine=tabtransformer.trt --fp16

Inference Latency Targets:

| Runtime      | Hardware   | Typical Latency |
|--------------|------------|-----------------|
| PyTorch      | CPU        | 5-20ms          |
| ONNX Runtime | CPU        | 2-8ms           |
| ONNX Runtime | GPU        | 0.5-2ms         |
| TensorRT     | NVIDIA GPU | <1ms            |

3. Online vs. Offline Feature Consistency

Problem: Training uses batch statistics; serving sees single rows.

Solution: Store computed features, don't recompute at inference.

| Feature Type           | Training            | Serving                  |
|------------------------|---------------------|--------------------------|
| Raw features           | Compute from source | Fetch from Feature Store |
| Base model predictions | OOF predictions     | Pre-computed daily batch |
| Tokenized features     | Batch transform     | Single-row transform     |

Preventing Train-Serve Skew:

  1. Tokenizer versioning: Hash tokenizer params and embed the hash in model metadata (see the sketch below)
  2. Feature validation: Assert feature distributions at inference time
  3. Shadow mode: Run new model in parallel, compare outputs before deployment
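
A sketch of one way to implement the versioning step; tokenizer_fingerprint and the attributes it reads are hypothetical:

import hashlib
import json

def tokenizer_fingerprint(tokenizer) -> str:
    # Hash the fitted configuration so serving can verify it loaded the exact
    # tokenizer version the model was trained with.
    params = {"n_bins": tokenizer.n_bins, "features": list(tokenizer.features)}
    blob = json.dumps(params, sort_keys=True).encode()
    return hashlib.sha256(blob).hexdigest()[:12]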

4. Deployment Architecture Options

Option A: Batch Prediction (Offline)

Airflow/Prefect → Load Data → Transform → Predict → Write to DB
  • Use for: Daily scoring of large datasets
  • Latency: Hours (acceptable)
  • Cost: Low (spot instances)

Option B: Real-Time API (Online)

API Gateway → Load Balancer → Inference Pod (ONNX/TensorRT) → Response
  • Use for: User-facing predictions
  • Latency: <50ms p99
  • Scaling: Horizontal pod autoscaling

Option C: Streaming (Near Real-Time)

Kafka → Feature Compute → Model Inference → Kafka → Downstream
  • Use for: Event-driven predictions
  • Latency: Seconds
  • Throughput: High (parallelizable)

Installation

# Clone the repository
git clone https://github.com/LEDazzio01/Tab-Transformer-Plus-Plus.git
cd Tab-Transformer-Plus-Plus

# Install dependencies
pip install numpy pandas torch scikit-learn jupyter

Quick Start

Option 1: Command-Line Interface

# Install the package
pip install -e .

# Train on built-in California Housing dataset
ttpp train --dataset cal_housing --epochs 10 --batch_size 1024

# Train on your own CSV data
ttpp train --train_data data/train.csv --target_col price --epochs 20

# Train with explicit train/test split
ttpp train --train_data train.csv --test_data test.csv --target_col target --n_folds 5

Option 2: Jupyter Notebook

jupyter notebook TabTransformer_Residual_Learning.ipynb

The notebook demonstrates the full pipeline using the California Housing dataset.

Option 3: Python API

import pandas as pd
from tab_transformer_plus_plus import (
    TabTransformerGated,
    TabularTokenizer,
    TTConfig,
    ModelFactory,
    Trainer,
    TrainingConfig,
    load_data,
    compute_rmse,
)

# Load your data (or use built-in datasets)
train_df, test_df, target_col, features = load_data(seed=42)

# Fit tokenizer on TRAINING data only (prevents leakage)
tokenizer = TabularTokenizer(n_bins=32, features=features, target=target_col)
tokenizer.fit(train_df)  # Never fit on full dataset!

# Transform data
X_train_tok, X_train_val = tokenizer.transform(train_df)
X_test_tok, X_test_val = tokenizer.transform(test_df)

# Create model with configuration
config = TTConfig(
    n_features=len(features),
    n_bins=32,
    embed_dim=64,
    n_heads=4,
    n_layers=3,
)
model = TabTransformerGated(config)

# Train using the Trainer class
train_config = TrainingConfig(epochs=10, batch_size=1024, learning_rate=2e-3)
trainer = Trainer(model=model, config=train_config)
# ... or use train_tabular for the full residual learning pipeline

Saving and Loading Models

# Save model checkpoint
ModelFactory.save_checkpoint(model, config, "model.pt")

# Load model checkpoint
model, config = ModelFactory.from_checkpoint("model.pt")

# Save/load tokenizer
tokenizer.save("tokenizer.pkl")
loaded_tokenizer = TabularTokenizer.load("tokenizer.pkl")

Custom Base Models

Register your own base models using the factory pattern:

from tab_transformer_plus_plus import BaseModelFactory, BaseModelConfig
from sklearn.linear_model import Ridge

# Register a custom model
BaseModelFactory.register("ridge", Ridge)

# Use in training
config = BaseModelConfig(model_type="ridge", hyperparams={"alpha": 1.0})

Training Callbacks

from tab_transformer_plus_plus import (
    EarlyStoppingCallback,
    LRSchedulerCallback,
    CheckpointCallback,
)

callbacks = [
    EarlyStoppingCallback(patience=5, min_delta=0.001),
    LRSchedulerCallback(scheduler_type="cosine"),
    CheckpointCallback(save_dir="checkpoints/", save_best_only=True),
]

trainer = Trainer(model=model, config=train_config, callbacks=callbacks)

Hyperparameters

All hyperparameters are centralized in the TTConfig and TrainingConfig classes:

| Category       | Parameter     | Default | Description                        |
|----------------|---------------|---------|------------------------------------|
| Tokenization   | n_bins        | 32      | Quantile bins for numeric features |
| Architecture   | embed_dim     | 64      | Embedding dimension (d_model)      |
|                | n_heads       | 4       | Multi-head attention heads         |
|                | n_layers      | 3       | Transformer encoder layers         |
|                | mlp_hidden    | 192     | Prediction head hidden dim         |
| Regularization | dropout       | 0.1     | Attention & FFN dropout            |
|                | emb_dropout   | 0.05    | Post-embedding dropout             |
|                | tokendrop_p   | 0.12    | TokenDrop probability              |
| Training       | epochs        | 10      | Training epochs                    |
|                | batch_size    | 1024    | Batch size                         |
|                | learning_rate | 2e-3    | AdamW learning rate                |
|                | weight_decay  | 0.01    | AdamW weight decay                 |

Access default values from constants:

from tab_transformer_plus_plus import (
    DEFAULT_EPOCHS,
    DEFAULT_BATCH_SIZE,
    DEFAULT_LEARNING_RATE,
    DEFAULT_N_BINS,
    GATE_INIT_VALUE,
)

File Structure

Tab-Transformer-Plus-Plus/
├── README.md                              # This documentation
├── CHANGELOG.md                           # Version history and changes
├── LICENSE                                # MIT License
├── pyproject.toml                         # Package configuration
├── requirements.txt                       # Dependencies
├── TabTransformer_Residual_Learning.ipynb # Interactive notebook demo
├── src/
│   └── tab_transformer_plus_plus/
│       ├── __init__.py                    # Package exports and public API
│       ├── base_models.py                 # BaseModelFactory and ensemble logic
│       ├── cli.py                         # Command-line interface
│       ├── configs.py                     # TTConfig, TrainingConfig, etc.
│       ├── constants.py                   # Default values and magic numbers
│       ├── data_loader.py                 # Data loading and splitting utilities
│       ├── exceptions.py                  # Custom exception hierarchy
│       ├── metrics.py                     # MetricRegistry and compute functions
│       ├── model.py                       # TabTransformerGated model (vectorized)
│       ├── protocols.py                   # Protocol classes for type safety
│       ├── tokenizer.py                   # TabularTokenizer (with serialization)
│       ├── train.py                       # High-level training pipeline
│       ├── trainer.py                     # Trainer class with callbacks
│       └── utils.py                       # Utility and visualization functions
└── tests/
    ├── test_cli.py                        # CLI argument parsing tests
    ├── test_integration.py                # End-to-end integration tests
    ├── test_model.py                      # Model architecture tests
    ├── test_tokenizer.py                  # Tokenizer edge case tests
    └── test_utils.py                      # Utility function tests

Citation

If you use this code, please cite the original TabTransformer paper:

@article{huang2020tabtransformer,
  title={TabTransformer: Tabular Data Modeling Using Contextual Embeddings},
  author={Huang, Xin and Khetan, Ashish and Cvitkovic, Milan and Karnin, Zohar},
  journal={arXiv preprint arXiv:2012.06678},
  year={2020}
}


License

This project is licensed under the MIT License. See LICENSE for details.
