Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 72 additions & 1 deletion examples/benchmarks/ort_inference_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,30 @@
"""Micro benchmark example for ONNXRuntime inference performance.

Commands to run:
In-house models:
python3 examples/benchmarks/ort_inference_performance.py
python3 examples/benchmarks/ort_inference_performance.py --model_source in-house

HuggingFace models:
python3 examples/benchmarks/ort_inference_performance.py \
--model_source huggingface --model_identifier bert-base-uncased
python3 examples/benchmarks/ort_inference_performance.py \
--model_source huggingface --model_identifier microsoft/resnet-50
python3 examples/benchmarks/ort_inference_performance.py \
--model_source huggingface --model_identifier deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B

Environment variables:
HF_TOKEN: HuggingFace token for gated models (optional)
"""

import argparse

from superbench.benchmarks import BenchmarkRegistry, Platform
from superbench.common.utils import logger

if __name__ == '__main__':

def run_inhouse_benchmark():
"""Run ORT inference with in-house torchvision models."""
context = BenchmarkRegistry.create_benchmark_context(
'ort-inference', platform=Platform.CUDA, parameters='--pytorch_models resnet50 resnet101 --precision float16'
)
Expand All @@ -21,3 +38,57 @@
benchmark.name, benchmark.return_code, benchmark.result
)
)
return benchmark


def run_huggingface_benchmark(model_identifier, precision='float16', batch_size=32, seq_length=512):
"""Run ORT inference with a HuggingFace model.

Args:
model_identifier: HuggingFace model ID (e.g., 'bert-base-uncased').
precision: Inference precision ('float32', 'float16', 'int8').
batch_size: Batch size for inference.
seq_length: Sequence length for transformer models.
"""
parameters = (
f'--model_source huggingface '
f'--model_identifier {model_identifier} '
f'--precision {precision} '
f'--batch_size {batch_size} '
f'--seq_length {seq_length}'
)

logger.info(f'Running ORT inference benchmark with HuggingFace model: {model_identifier}')

context = BenchmarkRegistry.create_benchmark_context('ort-inference', platform=Platform.CUDA, parameters=parameters)
benchmark = BenchmarkRegistry.launch_benchmark(context)
if benchmark:
logger.info(
'benchmark: {}, return code: {}, result: {}'.format(
benchmark.name, benchmark.return_code, benchmark.result
)
)
return benchmark


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='ORT inference benchmark')
parser.add_argument(
'--model_source',
type=str,
default='in-house',
choices=['in-house', 'huggingface'],
help='Source of the model: in-house (default) or huggingface'
)
parser.add_argument(
'--model_identifier', type=str, default='bert-base-uncased', help='HuggingFace model identifier'
)
parser.add_argument('--precision', type=str, default='float16', choices=['float32', 'float16', 'int8'])
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--seq_length', type=int, default=512)
args = parser.parse_args()

if args.model_source == 'huggingface':
run_huggingface_benchmark(args.model_identifier, args.precision, args.batch_size, args.seq_length)
else:
run_inhouse_benchmark()
80 changes: 79 additions & 1 deletion examples/benchmarks/tensorrt_inference_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,30 @@
"""Micro benchmark example for TensorRT inference performance.

Commands to run:
In-house models:
python3 examples/benchmarks/tensorrt_inference_performance.py
python3 examples/benchmarks/tensorrt_inference_performance.py --model_source in-house

HuggingFace models:
python3 examples/benchmarks/tensorrt_inference_performance.py \
--model_source huggingface --model_identifier bert-base-uncased
python3 examples/benchmarks/tensorrt_inference_performance.py \
--model_source huggingface --model_identifier microsoft/resnet-50
python3 examples/benchmarks/tensorrt_inference_performance.py \
--model_source huggingface --model_identifier deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B

Environment variables:
HF_TOKEN: HuggingFace token for gated models (optional)
"""

import argparse

from superbench.benchmarks import BenchmarkRegistry, Platform
from superbench.common.utils import logger

if __name__ == '__main__':

def run_inhouse_benchmark():
"""Run TensorRT inference with in-house torchvision models."""
context = BenchmarkRegistry.create_benchmark_context('tensorrt-inference', platform=Platform.CUDA)
benchmark = BenchmarkRegistry.launch_benchmark(context)
if benchmark:
Expand All @@ -19,3 +36,64 @@
benchmark.name, benchmark.return_code, benchmark.result
)
)
return benchmark


def run_huggingface_benchmark(model_identifier, precision='fp16', batch_size=32, seq_length=512, iterations=2048):
"""Run TensorRT inference with a HuggingFace model.

Args:
model_identifier: HuggingFace model ID (e.g., 'bert-base-uncased').
precision: Inference precision ('fp32', 'fp16', 'int8').
batch_size: Batch size for inference.
seq_length: Sequence length for transformer models.
iterations: Number of inference iterations.
"""
parameters = (
f'--model_source huggingface '
f'--model_identifier {model_identifier} '
f'--precision {precision} '
f'--batch_size {batch_size} '
f'--seq_length {seq_length} '
f'--iterations {iterations}'
)

logger.info(f'Running TensorRT inference benchmark with HuggingFace model: {model_identifier}')

context = BenchmarkRegistry.create_benchmark_context(
'tensorrt-inference', platform=Platform.CUDA, parameters=parameters
)
benchmark = BenchmarkRegistry.launch_benchmark(context)
if benchmark:
logger.info(
'benchmark: {}, return code: {}, result: {}'.format(
benchmark.name, benchmark.return_code, benchmark.result
)
)
return benchmark


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='TensorRT inference benchmark')
parser.add_argument(
'--model_source',
type=str,
default='in-house',
choices=['in-house', 'huggingface'],
help='Source of the model: in-house (default) or huggingface'
)
parser.add_argument(
'--model_identifier', type=str, default='bert-base-uncased', help='HuggingFace model identifier'
)
parser.add_argument('--precision', type=str, default='fp16', choices=['fp32', 'fp16', 'int8'])
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--seq_length', type=int, default=512)
parser.add_argument('--iterations', type=int, default=2048)
args = parser.parse_args()

if args.model_source == 'huggingface':
run_huggingface_benchmark(
args.model_identifier, args.precision, args.batch_size, args.seq_length, args.iterations
)
else:
run_inhouse_benchmark()
168 changes: 160 additions & 8 deletions superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,30 @@

"""Export PyTorch models to ONNX format."""

import inspect
from pathlib import Path

from packaging import version
import torch.hub
import torch.onnx
import torchvision.models
from transformers import BertConfig, GPT2Config, LlamaConfig

from superbench.benchmarks.model_benchmarks.pytorch_bert import BertBenchmarkModel
from superbench.benchmarks.model_benchmarks.pytorch_gpt2 import GPT2BenchmarkModel
from superbench.benchmarks.model_benchmarks.pytorch_lstm import LSTMBenchmarkModel
from superbench.benchmarks.model_benchmarks.pytorch_llama import LlamaBenchmarkModel
from superbench.benchmarks.model_benchmarks.pytorch_mixtral import MixtralBenchmarkModel
import traceback

if MixtralBenchmarkModel is not None:
from transformers import MixtralConfig
from superbench.common.utils import logger


class torch2onnxExporter():
"""PyTorch model to ONNX exporter."""
def __init__(self):
"""Constructor."""
from transformers import BertConfig, GPT2Config, LlamaConfig
from superbench.benchmarks.model_benchmarks.pytorch_bert import BertBenchmarkModel
from superbench.benchmarks.model_benchmarks.pytorch_gpt2 import GPT2BenchmarkModel
from superbench.benchmarks.model_benchmarks.pytorch_lstm import LSTMBenchmarkModel
from superbench.benchmarks.model_benchmarks.pytorch_llama import LlamaBenchmarkModel
from superbench.benchmarks.model_benchmarks.pytorch_mixtral import MixtralBenchmarkModel

self.num_classes = 100
self.lstm_input_size = 256
self.benchmark_models = {
Expand Down Expand Up @@ -129,6 +131,7 @@ def __init__(self):

# Only include Mixtral models if MixtralBenchmarkModel is available
if MixtralBenchmarkModel is not None:
from transformers import MixtralConfig
self.benchmark_models.update(
{
'mixtral-8x7b':
Expand Down Expand Up @@ -270,3 +273,152 @@ def export_benchmark_model(self, model_name, batch_size=1, seq_length=512):
del dummy_input
torch.cuda.empty_cache()
return file_name

Comment thread
Aishwarya-Tonpe marked this conversation as resolved.
def export_huggingface_model(self, model, model_name, batch_size=1, seq_length=512, output_dir=None):
"""Export a HuggingFace model to ONNX format.

Args:
model: HuggingFace model instance to export.
model_name (str): Name for the exported ONNX model file.
batch_size (int): Batch size of input. Defaults to 1.
seq_length (int): Sequence length of input. Defaults to 512.
output_dir (str): Output directory path. If None, uses default path.

Returns:
str: Exported ONNX model file path, or empty string if export fails.
"""
try:
# Use custom output directory if provided
output_path = Path(output_dir) if output_dir else self._onnx_model_path
file_name = str(output_path / (model_name + '.onnx'))

# Put model in eval mode and move to CUDA if available
model.eval()

# Disable cache to avoid DynamicCache issues with ONNX export
if hasattr(model.config, 'use_cache'):
model.config.use_cache = False

if torch.cuda.is_available():
model = model.cuda()

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Get model's dtype for inputs
model_dtype = next(model.parameters()).dtype
Comment thread
Aishwarya-Tonpe marked this conversation as resolved.

# Detect model type and create appropriate inputs
# Vision models use pixel_values, NLP models use input_ids
# Use HuggingFace's main_input_name property for automatic detection
main_input = getattr(model, 'main_input_name', 'input_ids')
is_vision_model = main_input == 'pixel_values'

if is_vision_model:
# Vision models: use pixel_values (batch_size, channels, height, width)
# Derive C/H/W from model config rather than hard-coding 3x224x224
num_channels = getattr(model.config, 'num_channels', 3)
image_size = getattr(model.config, 'image_size', 224)
if isinstance(image_size, (list, tuple)):
img_h, img_w = image_size[0], image_size[1]
else:
img_h, img_w = image_size, image_size

dummy_input = torch.randn(batch_size, num_channels, img_h, img_w, dtype=model_dtype, device=device)
input_names = ['pixel_values']
dynamic_axes = {'pixel_values': {0: 'batch_size'}, 'output': {0: 'batch_size'}}

# Wrapper for vision models
class VisionModelWrapper(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model

def forward(self, pixel_values):
outputs = self.model(pixel_values=pixel_values)
if hasattr(outputs, 'logits'):
return outputs.logits
elif hasattr(outputs, 'last_hidden_state'):
return outputs.last_hidden_state
else:
return outputs[0] if isinstance(outputs, (tuple, list)) else outputs

wrapped_model = VisionModelWrapper(model)
export_args = (dummy_input, )
else:
# NLP models: use input_ids and attention_mask
dummy_input = torch.ones((batch_size, seq_length), dtype=torch.int64, device=device)
attention_mask = torch.ones((batch_size, seq_length), dtype=torch.int64, device=device)
input_names = ['input_ids', 'attention_mask']
dynamic_axes = {
'input_ids': {
0: 'batch_size',
1: 'seq_length'
},
'attention_mask': {
0: 'batch_size',
1: 'seq_length'
},
'output': {
0: 'batch_size',
1: 'seq_length'
},
}
Comment thread
Aishwarya-Tonpe marked this conversation as resolved.

# Wrapper for NLP models
class NLPModelWrapper(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model

def forward(self, input_ids, attention_mask):
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
if hasattr(outputs, 'logits'):
return outputs.logits
elif hasattr(outputs, 'last_hidden_state'):
return outputs.last_hidden_state
else:
return outputs[0] if isinstance(outputs, (tuple, list)) else outputs

wrapped_model = NLPModelWrapper(model)
export_args = (dummy_input, attention_mask)

# Export to ONNX for large models (>2GB), use external data format
model_size_gb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024**3)
use_external_data = model_size_gb > 2.0

if use_external_data:
logger.info(f'Model size is {model_size_gb:.2f}GB, using external data format for ONNX export')

export_kwargs = {
'opset_version': 14,
'do_constant_folding': True,
'input_names': input_names,
'output_names': ['output'],
'dynamic_axes': dynamic_axes,
}
if use_external_data:
# PyTorch 2.8+ renamed 'use_external_data_format' to 'external_data'
sig = inspect.signature(torch.onnx.export)
if 'external_data' in sig.parameters:
export_kwargs['external_data'] = True
else:
export_kwargs['use_external_data_format'] = True

torch.onnx.export(
wrapped_model,
export_args,
file_name,
**export_kwargs,
)
Comment thread
Aishwarya-Tonpe marked this conversation as resolved.

# Clean up
del dummy_input
if torch.cuda.is_available():
torch.cuda.empty_cache()

return file_name

except Exception as e:
logger.error(f'Failed to export HuggingFace model to ONNX: {str(e)}')
logger.error(traceback.format_exc())
return ''
Loading
Loading