Update TorchDynamo-based ONNX Exporter memory usage example code. (#144139)

Address related comments earlier.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144139
Approved by: https://github.com/justinchuby

Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
Jay Zhang
2025-01-03 20:41:34 +00:00
committed by PyTorch MergeBot
parent 64bffb3124
commit b75f32b848
2 changed files with 9 additions and 8 deletions


@@ -21,7 +21,7 @@ The main advantage of this approach is that the `FX graph <https://pytorch.org/d
 bytecode analysis that preserves the dynamic nature of the model instead of using traditional static tracing techniques.
 In addition, during the export process, memory usage is significantly reduced compared to the TorchScript-enabled exporter.
-See the :doc:`documentation <onnx_dynamo_memory_usage>` for more information.
+See the :doc:`memory usage documentation <onnx_dynamo_memory_usage>` for more information.
 The exporter is designed to be modular and extensible. It is composed of the following components:


@@ -4,9 +4,9 @@ The previous TorchScript-based ONNX exporter would execute the model once to tra
 memory on your GPU if the model's memory requirements exceeded the available GPU memory. This issue has been addressed with the new
 TorchDynamo-based ONNX exporter.
-The TorchDynamo-based ONNX exporter leverages `FakeTensorMode <https://pytorch.org/docs/stable/torch.compiler_fake_tensor.html>`_ to
-avoid performing actual tensor computations during the export process. This approach results in significantly lower memory usage
-compared to the TorchScript-based ONNX exporter.
+The TorchDynamo-based ONNX exporter uses the torch.export.export() function to leverage
+`FakeTensorMode <https://pytorch.org/docs/stable/torch.compiler_fake_tensor.html>`_ to avoid performing actual tensor computations
+during the export process. This approach results in significantly lower memory usage compared to the TorchScript-based ONNX exporter.
 Below is an example demonstrating the memory usage difference between TorchScript-based and TorchDynamo-based ONNX exporters.
 In this example, we use the HighResNet model from MONAI. Before proceeding, please install it from PyPI:
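As a rough illustration of why FakeTensorMode keeps memory low (a standalone sketch, not part of the diff): fake tensors carry only metadata such as shape and dtype, so operations propagate shapes without allocating real storage or running kernels.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Under FakeTensorMode, factory functions produce fake tensors that hold
# metadata only; the matmul below allocates no real storage for the result
# and runs no actual kernel.
with FakeTensorMode():
    x = torch.randn(30, 1, 48, 48, 48)  # same shape as the example input
    w = torch.randn(48, 48)
    y = x @ w  # only the output shape/dtype are computed

print(y.shape)  # torch.Size([30, 1, 48, 48, 48])
```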
@@ -29,7 +29,6 @@ The code below could be run to generate a snapshot file which records the state
 import torch
-from torch.onnx.utils import export
 from monai.networks.nets import (
     HighResNet,
 )
@@ -44,17 +43,19 @@ The code below could be run to generate a snapshot file which records the state
 data = torch.randn(30, 1, 48, 48, 48, dtype=torch.float32).to("cuda")
 with torch.no_grad():
-    export(
+    onnx_program = torch.onnx.export(
         model,
         data,
         "torchscript_exporter_highresnet.onnx",
+        dynamo=False,
     )
-snapshot_name = f"torchscript_exporter_example.pickle"
+snapshot_name = "torchscript_exporter_example.pickle"
 print(f"generate {snapshot_name}")
 torch.cuda.memory._dump_snapshot(snapshot_name)
-print(f"Export is done.")
+print("Export is done.")
 Open `pytorch.org/memory_viz <https://pytorch.org/memory_viz>`_ and drag/drop the generated pickled snapshot file into the visualizer.
 The memory usage is described below: