[benchmark] Add torchscript jit.trace to benchmark option (#161223)
For comparing NativeRT and TorchScript, we add `torchscript-jit-trace` as an option in the benchmark. With this option, we can trace a model and run inference with the traced module using the TorchScript interpreter:

```
python ./benchmarks/dynamo/huggingface.py --performance --inference --torchscript-jit-trace
python ./benchmarks/dynamo/timm_models.py --performance --inference --torchscript-jit-trace
python ./benchmarks/dynamo/torchbench.py --performance --inference --torchscript-jit-trace
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161223
Approved by: https://github.com/huydhn
Committed by: PyTorch MergeBot
Parent: 2835cc5e91
Commit: 9d882fd9ff
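Before the hunks, here is a minimal standalone sketch of what the new option does for a single model, assuming a toy `nn.Module` (the real harness pulls inputs from the benchmark and caches the traced module per model, as the diff shows): trace once with `torch.jit.trace(strict=False)`, then run inference through the traced module.

```python
import torch

# Toy stand-in for a benchmark model; the harness uses real HF/timm/torchbench models.
model = torch.nn.Linear(8, 4).eval()
example_args = (torch.randn(2, 8),)

# strict=False tolerates models whose outputs are dicts or other mutable containers.
traced = torch.jit.trace(model, example_inputs=example_args, strict=False)

with torch.no_grad():
    out = traced(*example_args)  # executes via the TorchScript interpreter
```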
```diff
@@ -1103,6 +1103,8 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs):
         )
     elif args.export_nativert:
         frozen_model_iter_fn = export_nativert(model, example_inputs)
+    elif args.torchscript_jit_trace:
+        frozen_model_iter_fn = torchscript_jit_trace(model, example_inputs)
     else:
         frozen_model_iter_fn = torch._dynamo.run(model_iter_fn)
```
```diff
@@ -1481,6 +1483,28 @@ class NativeRTCache:
         return cls.cache[key]
 
 
+class JitTracedCache:
+    cache: dict[weakref.ref, Any] = {}
+
+    @classmethod
+    def load(cls, model, example_inputs):
+        key = weakref.ref(model)
+        if key not in cls.cache:
+            example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
+            if example_args:
+                jit_traced_module = torch.jit.trace(
+                    model, example_inputs=example_args, strict=False
+                )
+            else:
+                jit_traced_module = torch.jit.trace(
+                    model, example_kwarg_inputs=example_kwargs, strict=False
+                )
+
+            cls.cache[key] = jit_traced_module
+
+        return cls.cache[key]
+
+
 def export(model, example_inputs):
     from torch.export.dynamic_shapes import _combine_args, _tree_map_with_path
 
```
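`JitTracedCache` keys on `weakref.ref(model)` so each model is traced only once, and it picks the `torch.jit.trace` call form based on whether the benchmark supplies positional or keyword example inputs. A hedged standalone sketch of the keyword branch, using a toy module (`M` is illustrative, not from the PR):

```python
import torch

class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y

# example_kwarg_inputs traces a module whose example inputs are given by name.
traced = torch.jit.trace(
    M().eval(),
    example_kwarg_inputs={"x": torch.randn(3), "y": torch.randn(3)},
    strict=False,
)
out = traced(torch.randn(3), torch.randn(3))
```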
```diff
@@ -1527,6 +1551,16 @@ def export_aot_inductor(model, example_inputs, mode):
     return opt_aot_inductor
 
 
+def torchscript_jit_trace(model, example_inputs):
+    optimized = JitTracedCache.load(model, example_inputs)
+
+    def opt_jit_trace(_, example_inputs, collect_outputs=False):
+        example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
+        return optimized(*example_args, **example_kwargs)
+
+    return opt_jit_trace
+
+
 def download_retry_decorator(download_fn):
     """
     Decorator function for applying retry logic to a download function.
```
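`torchscript_jit_trace` returns a closure whose first parameter is ignored, matching the harness convention of calling `frozen_model_iter_fn(model, example_inputs)` with the original model while the traced module stays bound in the closure. A minimal standalone sketch of that pattern (the names `make_opt_fn` and `opt_fn` are illustrative, not from the PR):

```python
import torch

def make_opt_fn(model, example_args):
    traced = torch.jit.trace(model, example_inputs=example_args, strict=False)

    def opt_fn(_, example_args, collect_outputs=False):
        # The model argument is ignored; the traced module is captured above.
        return traced(*example_args)

    return opt_fn

model = torch.nn.Linear(4, 2).eval()
args = (torch.randn(1, 4),)
opt_fn = make_opt_fn(model, args)
out = opt_fn(model, args)  # same call shape the benchmark harness uses
```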
```diff
@@ -2277,6 +2311,7 @@ class BenchmarkRunner:
             self.args.export
             or self.args.export_aot_inductor
             or self.args.export_nativert
+            or self.args.torchscript_jit_trace
         ):
             # apply export on module directly
             # no need for n iterations
```
```diff
@@ -2673,7 +2708,11 @@ class BenchmarkRunner:
                 niters=1,
             )
 
-        if self.args.export_aot_inductor or self.args.export_nativert:
+        if (
+            self.args.export_aot_inductor
+            or self.args.export_nativert
+            or self.args.torchscript_jit_trace
+        ):
             optimized_model_iter_fn = optimize_ctx
         else:
             optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
```
```diff
@@ -3431,6 +3470,11 @@ def parse_args(args=None):
         action="store_true",
         help="Measure pass rate with Export+NativeRT",
     )
+    group.add_argument(
+        "--torchscript-jit-trace",
+        action="store_true",
+        help="Measure pass rate with TorchScript jit.trace",
+    )
     group.add_argument(
         "--xla", action="store_true", help="Compare TorchXLA to eager PyTorch"
     )
```
```diff
@@ -3876,6 +3920,10 @@ def run(runner, args, original_dir=None):
         optimize_ctx = export_nativert
         experiment = speedup_experiment
         output_filename = "export_nativert.csv"
+    elif args.torchscript_jit_trace:
+        optimize_ctx = torchscript_jit_trace
+        experiment = speedup_experiment
+        output_filename = "torchscript_jit_trace.csv"
     elif args.xla:
         (dev,) = args.devices
         os.environ["PJRT_DEVICE"] = {"cuda": "GPU", "cpu": "CPU"}[dev]
```