[AOTI] Update the OSS tutorial (#139956)
Summary: Update the OSS tutorial to use the new aoti_compile_and_package and aoti_load_package APIs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139956
Approved by: https://github.com/angelayi
ghstack dependencies: #139955
Committed by: PyTorch MergeBot
Parent: 07ad74635b
Commit: 63a0d6587e
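Before the diff itself, here is a condensed, illustrative sketch of the end-to-end flow the updated tutorial describes: export the model with ``torch.export.export``, AOT-compile it into a ``.pt2`` package with ``torch._inductor.aoti_compile_and_package``, and run the package with ``torch._inductor.aoti_load_package``. The ``Model`` class, dimensions, and paths below are stand-ins for illustration, not the tutorial's exact code.

.. code-block:: python

    import os

    import torch


    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(10, 16)

        def forward(self, x):
            return self.fc(x)


    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Model().to(device=device)
    example_inputs = (torch.randn(8, 10, device=device),)

    # Export first, marking the batch dimension of x as dynamic.
    batch_dim = torch.export.Dim("batch", min=1, max=1024)
    exported = torch.export.export(model, example_inputs, dynamic_shapes={"x": {0: batch_dim}})

    # AOT-compile the exported program and package it as model.pt2.
    package_path = torch._inductor.aoti_compile_and_package(
        exported,
        example_inputs,
        package_path=os.path.join(os.getcwd(), "model.pt2"),
    )

    # Load the package (here in the same process for brevity) and run it.
    compiled = torch._inductor.aoti_load_package(package_path)
    print(compiled(torch.randn(2, 10, device=device)))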
@@ -23,15 +23,16 @@ Model Compilation
 ---------------------------

 Using AOTInductor, you can still author the model in Python. The following
-example demonstrates how to invoke ``aot_compile`` to transform the model into a
+example demonstrates how to invoke ``aoti_compile_and_package`` to transform the model into a
 shared library.

-This API uses ``torch.export`` to capture the model into a computational graph,
+This API uses ``torch.export.export`` to capture the model into a computational graph,
 and then uses TorchInductor to generate a .so which can be run in a non-Python
-environment. For comprehensive details on the ``torch._export.aot_compile``
+environment. For comprehensive details on the
+``torch._inductor.aoti_compile_and_package``
 API, you can refer to the code
-`here <https://github.com/pytorch/pytorch/blob/92cc52ab0e48a27d77becd37f1683fd442992120/torch/_export/__init__.py#L891-L900C9>`__.
-For more details on ``torch.export``, you can refer to the :ref:`torch.export docs <torch.export>`.
+`here <https://github.com/pytorch/pytorch/blob/6ed237e5b528e3b01a7f1b6366b009dc6f30e6d6/torch/_inductor/__init__.py#L38-L105>`__.
+For more details on ``torch.export.export``, you can refer to the :ref:`torch.export docs <torch.export>`.

 .. note::

@@ -66,35 +67,48 @@ For more details on ``torch.export``, you can refer to the :ref:`torch.export do
     model = Model().to(device=device)
     example_inputs=(torch.randn(8, 10, device=device),)
     batch_dim = torch.export.Dim("batch", min=1, max=1024)
-    so_path = torch._export.aot_compile(
-        model,
+    # [Optional] Specify the first dimension of the input x as dynamic.
+    exported = torch.export.export(model, example_inputs, dynamic_shapes={"x": {0: batch_dim}})
+    # [Note] In this example we directly feed the exported module to aoti_compile_and_package.
+    # Depending on your use case, e.g. if your training platform and inference platform
+    # are different, you may choose to save the exported model using torch.export.save and
+    # then load it back using torch.export.load on your inference platform to run AOT compilation.
+    output_path = torch._inductor.aoti_compile_and_package(
+        exported,
         example_inputs,
-        # Specify the first dimension of the input x as dynamic
-        dynamic_shapes={"x": {0: batch_dim}},
-        # Specify the generated shared library path
-        options={"aot_inductor.output_path": os.path.join(os.getcwd(), "model.so")},
+        # [Optional] Specify the generated shared library path. If not specified,
+        # the generated artifact is stored in your system temp directory.
+        package_path=os.path.join(os.getcwd(), "model.pt2"),
     )

 In this illustrative example, the ``Dim`` parameter is employed to designate the first dimension of
 the input variable "x" as dynamic. Notably, the path and name of the compiled library remain unspecified,
 resulting in the shared library being stored in a temporary directory.
 To access this path from the C++ side, we save it to a file for later retrieval within the C++ code.


+Inference in Python
+---------------------------
+
+There are multiple ways to deploy the compiled artifact for inference, and one of that is using Python.
+We have provided a convinient utility API in Python ``torch._inductor.aoti_load_package`` for loading
+and running the artifact, as shown in the following example:
+
+.. code-block:: python
+
+    import os
+    import torch
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    model = torch._inductor.aoti_load_package(os.path.join(os.getcwd(), "model.pt2"))
+    print(model(torch.randn(8, 10, device=device)))
+
+
 Inference in C++
 ---------------------------

-Next, we use the following C++ file ``inference.cpp`` to load the shared library generated in the
-previous step, enabling us to conduct model predictions directly within a C++ environment.
-
-.. note::
-
-    The following code snippet assumes your system has a CUDA-enabled device and your model was
-    compiled to run on CUDA as shown previously.
-    In the absence of a GPU, it's necessary to make these adjustments in order to run it on a CPU:
-    1. Change ``model_container_runner_cuda.h`` to ``model_container_runner_cpu.h``
-    2. Change ``AOTIModelContainerRunnerCuda`` to ``AOTIModelContainerRunnerCpu``
-    3. Change ``at::kCUDA`` to ``at::kCPU``
+Next, we use the following example C++ file ``inference.cpp`` to load the compiled artifact,
+enabling us to conduct model predictions directly within a C++ environment.

 .. code-block:: cpp
@@ -102,22 +116,24 @@ previous step, enabling us to conduct model predictions directly within a C++ en
     #include <vector>

     #include <torch/torch.h>
-    #include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
+    #include <torch/csrc/inductor/aoti_package/model_package_loader.h>

     int main() {
         c10::InferenceMode mode;

-        torch::inductor::AOTIModelContainerRunnerCuda runner("model.so");
+        torch::inductor::AOTIModelPackageLoader loader("model.pt2");
+        torch::inductor::AOTIModelContainerRunner* runner = loader.get_runner();
+        // Assume running on CUDA
         std::vector<torch::Tensor> inputs = {torch::randn({8, 10}, at::kCUDA)};
-        std::vector<torch::Tensor> outputs = runner.run(inputs);
+        std::vector<torch::Tensor> outputs = runner->run(inputs);
         std::cout << "Result from the first inference:"<< std::endl;
         std::cout << outputs[0] << std::endl;

         // The second inference uses a different batch size and it works because we
-        // specified that dimension as dynamic when compiling model.so.
+        // specified that dimension as dynamic when compiling model.pt2.
         std::cout << "Result from the second inference:"<< std::endl;
-        std::vector<torch::Tensor> inputs2 = {torch::randn({2, 10}, at::kCUDA)};
-        std::cout << runner.run(inputs2)[0] << std::endl;
+        // Assume running on CUDA
+        std::cout << runner->run({torch::randn({1, 10}, at::kCUDA)})[0] << std::endl;

         return 0;
     }
@@ -133,10 +149,10 @@ automates the process of invoking ``python model.py`` for AOT compilation of the

     find_package(Torch REQUIRED)

-    add_executable(aoti_example inference.cpp model.so)
+    add_executable(aoti_example inference.cpp model.pt2)

     add_custom_command(
-        OUTPUT model.so
+        OUTPUT model.pt2
         COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/model.py
         DEPENDS model.py
     )