[benchmarks] Add nativert benchmark (#159922)

Add NativeRT as an inference option in the PT2 OSS benchmark suites, behind a new `--export-nativert` flag:

```
python ./benchmarks/dynamo/huggingface.py --performance --inference --export-nativert

python ./benchmarks/dynamo/timm_models.py --performance --inference --export-nativert

python ./benchmarks/dynamo/torchbench.py --performance --inference --export-nativert
```
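
For orientation, the flow the new flag measures looks roughly like the sketch below, distilled from the `NativeRTCache` added in this patch: export the module, package the `ExportedProgram` as a PT2 archive on disk, then replay it through NativeRT's `PyModelRunner`. The toy module is illustrative only, and `pt2_archive._package` is a private API used here exactly as the patch uses it.

```python
# Minimal sketch of the Export+NativeRT flow that --export-nativert measures.
# The Toy module is a stand-in for a benchmark model; the packaging and
# runner calls mirror NativeRTCache.load below.
import tempfile

import torch
from torch._C._nativert import PyModelRunner


class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


model, args = Toy(), (torch.randn(4),)

# Export once; run decompositions with an empty table, matching the patch.
ep = torch.export.export(model, args)
ep = ep.run_decompositions({})

# Serialize the exported program as a PT2 archive on disk...
with tempfile.NamedTemporaryFile(delete=False) as f:
    torch.export.pt2_archive._package.package_pt2(
        f, exported_programs={"forward": ep}
    )
    filename = f.name

# ...then replay it through the NativeRT runner, as opt_nativert does.
runner = PyModelRunner(filename, "forward")
out = runner.run(*args)
```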

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159922
Approved by: https://github.com/angelayi
Author:    Yiming Zhou
Date:      2025-08-08 03:38:28 +00:00
Committer: PyTorch MergeBot
Parent:    2ea40fba84
Commit:    017259f9c6


@@ -21,6 +21,7 @@ import shutil
 import signal
 import subprocess
 import sys
+import tempfile
 import time
 import weakref
 from contextlib import contextmanager
@@ -41,6 +42,7 @@ import torch._export
 import torch.distributed
 import torch.multiprocessing as mp
 from torch._C import _has_cuda as HAS_CUDA, _has_xpu as HAS_XPU
+from torch._C._nativert import PyModelRunner
 from torch._dynamo.profiler import fx_insert_profiling, Profiler
 from torch._dynamo.testing import (
     dummy_fx_compile,
@@ -1100,6 +1102,8 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs):
         frozen_model_iter_fn = export_aot_inductor(
             model, example_inputs, args.inductor_compile_mode
         )
+    elif args.export_nativert:
+        frozen_model_iter_fn = export_nativert(model, example_inputs)
     else:
         frozen_model_iter_fn = torch._dynamo.run(model_iter_fn)
 
@@ -1446,6 +1450,38 @@ class AOTInductorModelCache:
         return cls.cache.get(weakref.ref(model), (None, 0.0))[1]
 
 
+class NativeRTCache:
+    cache: dict[weakref.ref, Any] = {}
+
+    @classmethod
+    def load(cls, model, example_inputs):
+        from torch.export.dynamic_shapes import _combine_args, _tree_map_with_path
+
+        key = weakref.ref(model)
+        if key not in cls.cache:
+            example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
+            example_outputs = model(*example_args, **example_kwargs)
+            _register_dataclass_output_as_pytree(example_outputs)
+
+            combined_args = _combine_args(model, example_args, example_kwargs)
+            dynamic_shapes = _tree_map_with_path(
+                _produce_dynamic_shapes_for_export, combined_args
+            )
+
+            ep = torch.export.export(
+                model, example_args, example_kwargs, dynamic_shapes=dynamic_shapes
+            )
+            ep = ep.run_decompositions({})
+            with tempfile.NamedTemporaryFile(delete=False) as f:
+                torch.export.pt2_archive._package.package_pt2(
+                    f, exported_programs={"forward": ep}
+                )
+                filename = f.name
+            cls.cache[key] = PyModelRunner(filename, "forward")
+
+        return cls.cache[key]
+
+
 def export(model, example_inputs):
     from torch.export.dynamic_shapes import _combine_args, _tree_map_with_path
 
@@ -1472,6 +1508,16 @@ def export(model, example_inputs):
     return opt_export
 
 
+def export_nativert(model, example_inputs):
+    optimized = NativeRTCache.load(model, example_inputs)
+
+    def opt_nativert(_, example_inputs, collect_outputs=False):
+        example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
+        return optimized.run(*example_args, **example_kwargs)
+
+    return opt_nativert
+
+
 def export_aot_inductor(model, example_inputs, mode):
     optimized = AOTInductorModelCache.load(model, example_inputs, mode)
 
@@ -2228,7 +2274,11 @@ class BenchmarkRunner:
         try:
             model_copy = self.deepcopy_and_maybe_parallelize(model)
             self.init_optimizer(name, current_device, model_copy.parameters())
-            if self.args.export or self.args.export_aot_inductor:
+            if (
+                self.args.export
+                or self.args.export_aot_inductor
+                or self.args.export_nativert
+            ):
                 # apply export on module directly
                 # no need for n iterations
                 # the logic should be the same to self.model_iter_fn (forward_pass)
@@ -2624,7 +2674,7 @@ class BenchmarkRunner:
                 niters=1,
             )
 
-        if self.args.export_aot_inductor:
+        if self.args.export_aot_inductor or self.args.export_nativert:
             optimized_model_iter_fn = optimize_ctx
         else:
             optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
@@ -3377,6 +3427,11 @@ def parse_args(args=None):
         action="store_true",
         help="Measure pass rate with Export+AOTInductor",
     )
+    group.add_argument(
+        "--export-nativert",
+        action="store_true",
+        help="Measure pass rate with Export+NativeRT",
+    )
     group.add_argument(
         "--xla", action="store_true", help="Compare TorchXLA to eager PyTorch"
     )
@@ -3818,6 +3873,10 @@ def run(runner, args, original_dir=None):
         optimize_ctx = export
         experiment = speedup_experiment
         output_filename = "export.csv"
+    elif args.export_nativert:
+        optimize_ctx = export_nativert
+        experiment = speedup_experiment
+        output_filename = "export_nativert.csv"
     elif args.xla:
         (dev,) = args.devices
         os.environ["PJRT_DEVICE"] = {"cuda": "GPU", "cpu": "CPU"}[dev]
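
As with the AOTInductor path, the closure returned by `export_nativert` ignores the live module passed in on each iteration and replays the cached runner, which is why the harness now assigns `optimize_ctx` directly instead of wrapping `self.model_iter_fn`. Below is a minimal sketch of the compile-once, weakref-keyed caching pattern that `NativeRTCache` shares with `AOTInductorModelCache`; `load_once` and `build` are hypothetical stand-ins for the cache lookup and the export/packaging/runner construction.

```python
# Sketch of the weakref-keyed, compile-once cache pattern (names hypothetical).
import weakref
from typing import Any, Callable

_cache: dict[weakref.ref, Any] = {}


def load_once(model: Any, build: Callable[[Any], Any]) -> Any:
    # weakref keys let entries die with their model instead of pinning it.
    key = weakref.ref(model)
    if key not in _cache:
        _cache[key] = build(model)  # pay the export/packaging cost once
    return _cache[key]  # later benchmark iterations hit the cache


class Toy:  # weakref-able stand-in for an nn.Module
    pass


m = Toy()
# Second lookup returns the cached object; build is not invoked again.
assert load_once(m, lambda mod: "runner") is load_once(m, lambda mod: "other")
```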