[benchmarks] Add nativert benchmark (#159922)
Add NativeRT as an option in the PT2 OSS benchmark:

```
python ./benchmarks/dynamo/huggingface.py --performance --inference --export-nativert
python ./benchmarks/dynamo/timm_models.py --performance --inference --export-nativert
python ./benchmarks/dynamo/torchbench.py --performance --inference --export-nativert
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159922
Approved by: https://github.com/angelayi
Committed by: PyTorch MergeBot
Parent: 2ea40fba84
Commit: 017259f9c6
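For orientation, here is a minimal standalone sketch of the Export+NativeRT flow the new flag exercises, distilled from the NativeRTCache added in the diff below. The toy module and inputs are illustrative, and `PyModelRunner` assumes a PyTorch build with NativeRT enabled:

```python
# Sketch of the export -> package -> NativeRT run pipeline (toy example,
# mirroring the calls made by NativeRTCache in this PR).
import tempfile

import torch
from torch._C._nativert import PyModelRunner  # needs a NativeRT-enabled build
from torch.export.pt2_archive._package import package_pt2


class ToyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


model = ToyModel()
example_args = (torch.randn(4, 8),)

# 1. Trace the eager module into an ExportedProgram and run decompositions.
ep = torch.export.export(model, example_args)
ep = ep.run_decompositions({})

# 2. Package the exported program into a PT2 archive on disk.
with tempfile.NamedTemporaryFile(delete=False) as f:
    package_pt2(f, exported_programs={"forward": ep})
    filename = f.name

# 3. Load the archive in NativeRT and execute the "forward" entry point.
runner = PyModelRunner(filename, "forward")
out = runner.run(*example_args)
```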
benchmarks/dynamo/common.py:

```diff
@@ -21,6 +21,7 @@ import shutil
 import signal
 import subprocess
 import sys
+import tempfile
 import time
 import weakref
 from contextlib import contextmanager
```
```diff
@@ -41,6 +42,7 @@ import torch._export
 import torch.distributed
 import torch.multiprocessing as mp
 from torch._C import _has_cuda as HAS_CUDA, _has_xpu as HAS_XPU
+from torch._C._nativert import PyModelRunner
 from torch._dynamo.profiler import fx_insert_profiling, Profiler
 from torch._dynamo.testing import (
     dummy_fx_compile,
```
```diff
@@ -1100,6 +1102,8 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs):
         frozen_model_iter_fn = export_aot_inductor(
             model, example_inputs, args.inductor_compile_mode
         )
+    elif args.export_nativert:
+        frozen_model_iter_fn = export_nativert(model, example_inputs)
     else:
         frozen_model_iter_fn = torch._dynamo.run(model_iter_fn)
 
```
```diff
@@ -1446,6 +1450,38 @@ class AOTInductorModelCache:
         return cls.cache.get(weakref.ref(model), (None, 0.0))[1]
 
 
+class NativeRTCache:
+    cache: dict[weakref.ref, Any] = {}
+
+    @classmethod
+    def load(cls, model, example_inputs):
+        from torch.export.dynamic_shapes import _combine_args, _tree_map_with_path
+
+        key = weakref.ref(model)
+        if key not in cls.cache:
+            example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
+            example_outputs = model(*example_args, **example_kwargs)
+            _register_dataclass_output_as_pytree(example_outputs)
+
+            combined_args = _combine_args(model, example_args, example_kwargs)
+            dynamic_shapes = _tree_map_with_path(
+                _produce_dynamic_shapes_for_export, combined_args
+            )
+
+            ep = torch.export.export(
+                model, example_args, example_kwargs, dynamic_shapes=dynamic_shapes
+            )
+            ep = ep.run_decompositions({})
+            with tempfile.NamedTemporaryFile(delete=False) as f:
+                torch.export.pt2_archive._package.package_pt2(
+                    f, exported_programs={"forward": ep}
+                )
+                filename = f.name
+            cls.cache[key] = PyModelRunner(filename, "forward")
+
+        return cls.cache[key]
+
+
 def export(model, example_inputs):
     from torch.export.dynamic_shapes import _combine_args, _tree_map_with_path
 
```
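The cache keys on a weak reference to the model so each model is exported and packaged only once per process, and `dynamic_shapes` is derived by mapping the harness helper `_produce_dynamic_shapes_for_export` over the combined args. Outside the harness, a hand-written equivalent of such a spec might look like this (toy module; batch-dim-only dynamism is an assumption for illustration, not necessarily what the helper produces):

```python
import torch
from torch.export import Dim, export


class ToyModel(torch.nn.Module):
    def forward(self, x):
        return x * 2


# Explicitly mark dim 0 of "x" as dynamic so the exported program
# accepts varying batch sizes at run time.
batch = Dim("batch")
ep = export(ToyModel(), (torch.randn(4, 8),), dynamic_shapes={"x": {0: batch}})
```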
```diff
@@ -1472,6 +1508,16 @@ def export(model, example_inputs):
     return opt_export
 
 
+def export_nativert(model, example_inputs):
+    optimized = NativeRTCache.load(model, example_inputs)
+
+    def opt_nativert(_, example_inputs, collect_outputs=False):
+        example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
+        return optimized.run(*example_args, **example_kwargs)
+
+    return opt_nativert
+
+
 def export_aot_inductor(model, example_inputs, mode):
     optimized = AOTInductorModelCache.load(model, example_inputs, mode)
 
```
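Note that the adapter's signature mirrors the harness's iteration functions, `fn(model, example_inputs, collect_outputs=...)`: the first argument (the eager model) is deliberately ignored, so timed iterations exercise only NativeRT execution, never export or packaging. A hypothetical harness-style call sequence:

```python
# Hypothetical usage mirroring the benchmark harness (names from the diff;
# model and example_inputs are assumed to exist).
frozen_model_iter_fn = export_nativert(model, example_inputs)  # exports once, cached

# Subsequent calls reuse the cached PyModelRunner; the eager-model argument
# is ignored by opt_nativert.
for _ in range(10):
    frozen_model_iter_fn(model, example_inputs)
```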
```diff
@@ -2228,7 +2274,11 @@ class BenchmarkRunner:
             try:
                 model_copy = self.deepcopy_and_maybe_parallelize(model)
                 self.init_optimizer(name, current_device, model_copy.parameters())
-                if self.args.export or self.args.export_aot_inductor:
+                if (
+                    self.args.export
+                    or self.args.export_aot_inductor
+                    or self.args.export_nativert
+                ):
                     # apply export on module directly
                     # no need for n iterations
                     # the logic should be the same to self.model_iter_fn (forward_pass)
```
```diff
@@ -2624,7 +2674,7 @@
                 niters=1,
             )
 
-        if self.args.export_aot_inductor:
+        if self.args.export_aot_inductor or self.args.export_nativert:
             optimized_model_iter_fn = optimize_ctx
         else:
             optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
```
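This branch reflects the two backend shapes in the harness: compile-style backends return a wrapper to apply to the eager iteration function, while export-style backends (AOTInductor, and now NativeRT) already return a ready-to-call iteration function. A schematic of the distinction, with hypothetical minimal signatures:

```python
# Compile-style: optimize_ctx decorates the eager iteration function.
def compile_style_ctx(model_iter_fn):
    def wrapped(model, example_inputs, collect_outputs=False):
        return model_iter_fn(model, example_inputs, collect_outputs)

    return wrapped


# Export-style (AOTInductor, NativeRT): optimize_ctx is already the
# iteration function, so the harness uses it directly without wrapping.
def export_style_ctx(model, example_inputs, collect_outputs=False):
    raise NotImplementedError  # placeholder body in this sketch
```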
```diff
@@ -3377,6 +3427,11 @@ def parse_args(args=None):
         action="store_true",
         help="Measure pass rate with Export+AOTInductor",
     )
+    group.add_argument(
+        "--export-nativert",
+        action="store_true",
+        help="Measure pass rate with Export+NativeRT",
+    )
     group.add_argument(
         "--xla", action="store_true", help="Compare TorchXLA to eager PyTorch"
     )
```
```diff
@@ -3818,6 +3873,10 @@ def run(runner, args, original_dir=None):
         optimize_ctx = export
         experiment = speedup_experiment
         output_filename = "export.csv"
+    elif args.export_nativert:
+        optimize_ctx = export_nativert
+        experiment = speedup_experiment
+        output_filename = "export_nativert.csv"
     elif args.xla:
         (dev,) = args.devices
         os.environ["PJRT_DEVICE"] = {"cuda": "GPU", "cpu": "CPU"}[dev]
```