[AOTI] Update the OSS tutorial (#139956)

Summary: Update the OSS tutorial to use the new aoti_compile_and_package and aoti_load_package APIs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139956
Approved by: https://github.com/angelayi
ghstack dependencies: #139955
Author: Bin Bao
Date: 2024-11-06 17:54:02 -08:00
Committed by: PyTorch MergeBot
Parent: 07ad74635b
Commit: 63a0d6587e


@@ -23,15 +23,16 @@ Model Compilation
---------------------------
Using AOTInductor, you can still author the model in Python. The following
-example demonstrates how to invoke ``aot_compile`` to transform the model into a
+example demonstrates how to invoke ``aoti_compile_and_package`` to transform the model into a
shared library.
-This API uses ``torch.export`` to capture the model into a computational graph,
+This API uses ``torch.export.export`` to capture the model into a computational graph,
and then uses TorchInductor to generate a .so which can be run in a non-Python
-environment. For comprehensive details on the ``torch._export.aot_compile``
+environment. For comprehensive details on the
+``torch._inductor.aoti_compile_and_package``
API, you can refer to the code
-`here <https://github.com/pytorch/pytorch/blob/92cc52ab0e48a27d77becd37f1683fd442992120/torch/_export/__init__.py#L891-L900C9>`__.
-For more details on ``torch.export``, you can refer to the :ref:`torch.export docs <torch.export>`.
+`here <https://github.com/pytorch/pytorch/blob/6ed237e5b528e3b01a7f1b6366b009dc6f30e6d6/torch/_inductor/__init__.py#L38-L105>`__.
+For more details on ``torch.export.export``, you can refer to the :ref:`torch.export docs <torch.export>`.
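To make the capture step concrete, here is a minimal, self-contained sketch (the ``TinyModel`` class and shapes are illustrative, not taken from the tutorial) showing that ``torch.export.export`` produces an ``ExportedProgram`` whose graph can be inspected before AOTInductor is involved:

.. code-block:: python

    import torch

    class TinyModel(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.relu(x) + 1

    # Capture the module into a computational graph (an ExportedProgram).
    ep = torch.export.export(TinyModel(), (torch.randn(4, 8),))
    print(ep.graph)                        # inspect the captured graph
    print(ep.module()(torch.randn(4, 8)))  # the exported program is still runnable in eager mode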
.. note::
@@ -66,35 +67,48 @@ For more details on ``torch.export``, you can refer to the :ref:`torch.export do
model = Model().to(device=device)
example_inputs=(torch.randn(8, 10, device=device),)
batch_dim = torch.export.Dim("batch", min=1, max=1024)
-so_path = torch._export.aot_compile(
-model,
+# [Optional] Specify the first dimension of the input x as dynamic.
+exported = torch.export.export(model, example_inputs, dynamic_shapes={"x": {0: batch_dim}})
+# [Note] In this example we directly feed the exported module to aoti_compile_and_package.
+# Depending on your use case, e.g. if your training platform and inference platform
+# are different, you may choose to save the exported model using torch.export.save and
+# then load it back using torch.export.load on your inference platform to run AOT compilation.
+output_path = torch._inductor.aoti_compile_and_package(
+exported,
example_inputs,
-# Specify the first dimension of the input x as dynamic
-dynamic_shapes={"x": {0: batch_dim}},
-# Specify the generated shared library path
-options={"aot_inductor.output_path": os.path.join(os.getcwd(), "model.so")},
+# [Optional] Specify the generated shared library path. If not specified,
+# the generated artifact is stored in your system temp directory.
+package_path=os.path.join(os.getcwd(), "model.pt2"),
)
In this illustrative example, the ``Dim`` parameter is employed to designate the first dimension of
the input variable "x" as dynamic. Notably, the path and name of the compiled library remain unspecified,
resulting in the shared library being stored in a temporary directory.
To access this path from the C++ side, we save it to a file for later retrieval within the C++ code.
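The comment in the snippet above mentions splitting export and AOT compilation across different machines. A minimal sketch of that flow, with an illustrative file name and a compile call that simply mirrors the one above, could look like this:

.. code-block:: python

    import os
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # On the training platform: serialize the exported program produced above.
    torch.export.save(exported, "exported_model.pt2")

    # On the inference platform: load it back and run AOT compilation there.
    exported = torch.export.load("exported_model.pt2")
    example_inputs = (torch.randn(8, 10, device=device),)
    output_path = torch._inductor.aoti_compile_and_package(
        exported,
        example_inputs,
        package_path=os.path.join(os.getcwd(), "model.pt2"),
    )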
+Inference in Python
+---------------------------
+There are multiple ways to deploy the compiled artifact for inference, and one of them is using Python.
+We have provided a convenient utility API in Python ``torch._inductor.aoti_load_package`` for loading
+and running the artifact, as shown in the following example:
+.. code-block:: python
+import os
+import torch
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model = torch._inductor.aoti_load_package(os.path.join(os.getcwd(), "model.pt2"))
+print(model(torch.randn(8, 10, device=device)))
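Because the first dimension of ``x`` was marked dynamic at export time, the same package also accepts other batch sizes from Python, for example:

.. code-block:: python

    # Continuing from the snippet above: a different batch size works because
    # dim 0 of "x" was exported as dynamic.
    print(model(torch.randn(2, 10, device=device)))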
Inference in C++
---------------------------
-Next, we use the following C++ file ``inference.cpp`` to load the shared library generated in the
-previous step, enabling us to conduct model predictions directly within a C++ environment.
.. note::
The following code snippet assumes your system has a CUDA-enabled device and your model was
compiled to run on CUDA as shown previously.
In the absence of a GPU, it's necessary to make these adjustments in order to run it on a CPU:
1. Change ``model_container_runner_cuda.h`` to ``model_container_runner_cpu.h``
2. Change ``AOTIModelContainerRunnerCuda`` to ``AOTIModelContainerRunnerCpu``
3. Change ``at::kCUDA`` to ``at::kCPU``
+Next, we use the following example C++ file ``inference.cpp`` to load the compiled artifact,
+enabling us to conduct model predictions directly within a C++ environment.
.. code-block:: cpp
@@ -102,22 +116,24 @@ previous step, enabling us to conduct model predictions directly within a C++ en
#include <vector>
#include <torch/torch.h>
-#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
+#include <torch/csrc/inductor/aoti_package/model_package_loader.h>
int main() {
c10::InferenceMode mode;
-torch::inductor::AOTIModelContainerRunnerCuda runner("model.so");
+torch::inductor::AOTIModelPackageLoader loader("model.pt2");
+torch::inductor::AOTIModelContainerRunner* runner = loader.get_runner();
+// Assume running on CUDA
std::vector<torch::Tensor> inputs = {torch::randn({8, 10}, at::kCUDA)};
-std::vector<torch::Tensor> outputs = runner.run(inputs);
+std::vector<torch::Tensor> outputs = runner->run(inputs);
std::cout << "Result from the first inference:"<< std::endl;
std::cout << outputs[0] << std::endl;
// The second inference uses a different batch size and it works because we
-// specified that dimension as dynamic when compiling model.so.
+// specified that dimension as dynamic when compiling model.pt2.
std::cout << "Result from the second inference:"<< std::endl;
-std::vector<torch::Tensor> inputs2 = {torch::randn({2, 10}, at::kCUDA)};
-std::cout << runner.run(inputs2)[0] << std::endl;
+// Assume running on CUDA
+std::cout << runner->run({torch::randn({1, 10}, at::kCUDA)})[0] << std::endl;
return 0;
}
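If the package was instead compiled for CPU, only the input tensor device should need to change in this loader-based code, since the device-specific runner headers from the earlier note no longer appear here. A sketch, assuming a CPU-compiled ``model.pt2``:

.. code-block:: cpp

    // Assumes model.pt2 was compiled for CPU; only the tensor device differs.
    torch::inductor::AOTIModelPackageLoader loader("model.pt2");
    torch::inductor::AOTIModelContainerRunner* runner = loader.get_runner();
    std::vector<torch::Tensor> inputs = {torch::randn({8, 10}, at::kCPU)};
    std::cout << runner->run(inputs)[0] << std::endl;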
@@ -133,10 +149,10 @@ automates the process of invoking ``python model.py`` for AOT compilation of the
find_package(Torch REQUIRED)
-add_executable(aoti_example inference.cpp model.so)
+add_executable(aoti_example inference.cpp model.pt2)
add_custom_command(
-OUTPUT model.so
+OUTPUT model.pt2
COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/model.py
DEPENDS model.py
)
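The excerpt above covers only the executable target and the step that regenerates ``model.pt2``; in a typical libtorch-based setup the remaining, unshown lines link the executable against Torch, roughly as follows (these lines are an assumption, not part of this excerpt):

.. code-block:: cmake

    # Assumed remainder of such a CMakeLists.txt.
    target_link_libraries(aoti_example "${TORCH_LIBRARIES}")
    set_property(TARGET aoti_example PROPERTY CXX_STANDARD 17)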