[benchmark] Add torchscript jit.trace to benchmark option (#161223)

For comparing NativeRT and TorchScript, we add `--torchscript-jit-trace` as an option in the benchmark. With this option, we can trace a model and run inference with the traced module using the TorchScript interpreter:

```
python ./benchmarks/dynamo/huggingface.py --performance --inference --torchscript-jit-trace

python ./benchmarks/dynamo/timm_models.py --performance --inference --torchscript-jit-trace

python ./benchmarks/dynamo/torchbench.py --performance --inference --torchscript-jit-trace
```
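For context, here is a minimal, self-contained sketch of what the option exercises: `torch.jit.trace` records the operations executed for the given example inputs and returns a module that runs through the TorchScript interpreter. `TinyModel` and its inputs are hypothetical, not taken from the benchmark suite.

```
import torch

class TinyModel(torch.nn.Module):  # hypothetical stand-in for a benchmark model
    def forward(self, x):
        return torch.relu(x) + 1.0

model = TinyModel().eval()
example = torch.randn(2, 3)

# strict=False tolerates non-tensor container outputs, matching how the
# benchmark invokes torch.jit.trace below
traced = torch.jit.trace(model, example_inputs=(example,), strict=False)

with torch.no_grad():
    out = traced(example)  # inference now runs in the TorchScript interpreter
```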

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161223
Approved by: https://github.com/huydhn
Author: Yiming Zhou
Date: 2025-08-22 21:38:25 +00:00
Committed by: PyTorch MergeBot
Parent: 2835cc5e91
Commit: 9d882fd9ff


```
@@ -1103,6 +1103,8 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs):
         )
     elif args.export_nativert:
         frozen_model_iter_fn = export_nativert(model, example_inputs)
+    elif args.torchscript_jit_trace:
+        frozen_model_iter_fn = torchscript_jit_trace(model, example_inputs)
     else:
         frozen_model_iter_fn = torch._dynamo.run(model_iter_fn)
```
```
@@ -1481,6 +1483,28 @@ class NativeRTCache:
         return cls.cache[key]
+
+
+class JitTracedCache:
+    cache: dict[weakref.ref, Any] = {}
+
+    @classmethod
+    def load(cls, model, example_inputs):
+        key = weakref.ref(model)
+        if key not in cls.cache:
+            example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
+            if example_args:
+                jit_traced_module = torch.jit.trace(
+                    model, example_inputs=example_args, strict=False
+                )
+            else:
+                jit_traced_module = torch.jit.trace(
+                    model, example_kwarg_inputs=example_kwargs, strict=False
+                )
+            cls.cache[key] = jit_traced_module
+        return cls.cache[key]
 
 
 def export(model, example_inputs):
     from torch.export.dynamic_shapes import _combine_args, _tree_map_with_path
```
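Two details of `JitTracedCache` are worth noting: keying on `weakref.ref(model)` means each model object is traced at most once without the cache keeping it alive, and the `if example_args` branch exists because `torch.jit.trace` accepts positional inputs via `example_inputs` but keyword-only inputs via `example_kwarg_inputs`. A small illustration of the two calling conventions (`KwModel` is hypothetical):

```
import torch

class KwModel(torch.nn.Module):  # hypothetical model for illustration
    def forward(self, x, y):
        return x + y

m = KwModel().eval()

# Positional example inputs go through example_inputs=...
traced_args = torch.jit.trace(
    m, example_inputs=(torch.randn(2), torch.randn(2)), strict=False
)

# Keyword-only example inputs go through example_kwarg_inputs=... (a dict keyed
# by forward()'s argument names), hence the if/else in JitTracedCache.load
traced_kwargs = torch.jit.trace(
    m, example_kwarg_inputs={"x": torch.randn(2), "y": torch.randn(2)}, strict=False
)
```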
```
@@ -1527,6 +1551,16 @@ def export_aot_inductor(model, example_inputs, mode):
     return opt_aot_inductor
+
+
+def torchscript_jit_trace(model, example_inputs):
+    optimized = JitTracedCache.load(model, example_inputs)
+
+    def opt_jit_trace(_, example_inputs, collect_outputs=False):
+        example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
+        return optimized(*example_args, **example_kwargs)
+
+    return opt_jit_trace
 
 
 def download_retry_decorator(download_fn):
     """
     Decorator function for applying retry logic to a download function.
```
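The returned `opt_jit_trace` matches the benchmark's `model_iter_fn(model, example_inputs, collect_outputs)` calling convention but ignores its first argument, since the traced module is already captured in the closure. A standalone sketch of the same pattern, using the hypothetical name `make_traced_runner`:

```
import torch

def make_traced_runner(model, example_args):
    # hypothetical standalone analogue of torchscript_jit_trace/opt_jit_trace:
    # trace once up front, then expose the benchmark-style calling convention
    traced = torch.jit.trace(model, example_inputs=example_args, strict=False)

    def runner(_, example_inputs, collect_outputs=False):
        return traced(*example_inputs)  # the first argument is ignored

    return runner

m = torch.nn.Linear(3, 3).eval()
args = (torch.randn(2, 3),)
run_fn = make_traced_runner(m, args)
out = run_fn(m, args)
```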
```
@@ -2277,6 +2311,7 @@ class BenchmarkRunner:
             self.args.export
             or self.args.export_aot_inductor
             or self.args.export_nativert
+            or self.args.torchscript_jit_trace
         ):
             # apply export on module directly
             # no need for n iterations
```
```
@@ -2673,7 +2708,11 @@ class BenchmarkRunner:
                 niters=1,
             )
-        if self.args.export_aot_inductor or self.args.export_nativert:
+        if (
+            self.args.export_aot_inductor
+            or self.args.export_nativert
+            or self.args.torchscript_jit_trace
+        ):
             optimized_model_iter_fn = optimize_ctx
         else:
             optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
```
```
@@ -3431,6 +3470,11 @@ def parse_args(args=None):
         action="store_true",
         help="Measure pass rate with Export+NativeRT",
     )
+    group.add_argument(
+        "--torchscript-jit-trace",
+        action="store_true",
+        help="Measure pass rate with TorchScript jit.trace",
+    )
     group.add_argument(
         "--xla", action="store_true", help="Compare TorchXLA to eager PyTorch"
     )
```
```
@@ -3876,6 +3920,10 @@ def run(runner, args, original_dir=None):
         optimize_ctx = export_nativert
         experiment = speedup_experiment
         output_filename = "export_nativert.csv"
+    elif args.torchscript_jit_trace:
+        optimize_ctx = torchscript_jit_trace
+        experiment = speedup_experiment
+        output_filename = "torchscript_jit_trace.csv"
     elif args.xla:
         (dev,) = args.devices
         os.environ["PJRT_DEVICE"] = {"cuda": "GPU", "cpu": "CPU"}[dev]
```