[nativert] aoti (#162353)

Summary: Wire AOTInductor (AOTI) into NativeRT as a delegate backend. An AOTIDelegateExecutor extracts the AOTI-compiled .so from the PT2 archive and runs it through a per-device AOTIModelContainerRunnerRegistry (CPU creator always registered; CUDA behind USE_CUDA). An ETCallDelegateKernel routes torch.ops.higher_order.executorch_call_delegate nodes to that executor. Lowering/packaging utilities land under torch.nativert.backends, along with docs and C++/Python tests.
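
A minimal sketch of the end-to-end flow this enables, assembled from the Python test added in this diff (the module and file names are illustrative):

    import torch
    from torch._C._nativert import PyModelRunner
    from torch.nativert.backends._lower_utils import (
        lower_exported_program,
        package_nativert_with_aoti_delegate,
    )

    class M(torch.nn.Module):  # illustrative module
        def forward(self, x):
            return torch.relu(x)

    sample_inputs = (torch.randn(4, 4),)
    original_ep = torch.export.export(M(), sample_inputs)

    # Lower the exported program to an AOTI delegate.
    aoti_delegate_ep, aoti_files = lower_exported_program(original_ep, "model", "aoti")

    # Package everything NativeRT needs to run the delegate into a .pt2 archive.
    with open("model.pt2", "wb") as f:
        package_nativert_with_aoti_delegate(
            f, "model", "aoti", original_ep, aoti_delegate_ep, aoti_files
        )

    # Execute via NativeRT; the delegate is addressed as "<model_name>-<backend_id>".
    runner = PyModelRunner("model.pt2", "model-aoti")
    out = runner.run(*sample_inputs)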

Test Plan:
CI

Rollback Plan:

Differential Revision: D81731425

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162353
Approved by: https://github.com/yiming0416
Author: dolpm
Date: 2025-09-12 05:56:19 +00:00
Committed by: PyTorch MergeBot
Parent: 28e8531032
Commit: 30e16d6389
18 changed files with 581 additions and 5 deletions

@@ -638,10 +638,13 @@ libtorch_nativert_sources = [
"torch/nativert/kernels/KernelHandlerRegistry.cpp",
"torch/nativert/kernels/TritonKernel.cpp",
"torch/nativert/executor/triton/CpuTritonKernelManager.cpp",
"torch/nativert/executor/AOTInductorDelegateExecutor.cpp",
"torch/nativert/kernels/ETCallDelegateKernel.cpp",
]
libtorch_nativert_cuda_sources = [
"torch/nativert/executor/triton/CudaTritonKernelManager.cpp",
"torch/nativert/executor/AOTInductorModelContainerCudaShim.cpp",
]
torch_mobile_tracer_sources = [

docs/source/nativert.rst (new file)

@@ -0,0 +1,17 @@
torch.nativert
==============

.. automodule:: torch.nativert
.. currentmodule:: torch.nativert
.. py:module:: torch.nativert
   :noindex:

torch.nativert.backends
-----------------------

.. automodule:: torch.nativert.backends
.. currentmodule:: torch.nativert.backends
.. py:module:: torch.nativert.backends
   :noindex:

@@ -56,6 +56,7 @@ torch.monitor <monitor>
torch.signal <signal>
torch.special <special>
torch.overrides
torch.nativert <nativert>
torch.package <package>
profiler
nn.init

@@ -43,10 +43,14 @@ set(NATIVERT_TEST_SRCS
${TORCH_ROOT}/torch/nativert/executor/triton/CpuTritonKernelManager.cpp
${TORCH_ROOT}/torch/nativert/kernels/TritonKernel.cpp
${TORCH_ROOT}/torch/nativert/executor/DelegateExecutor.cpp
${TORCH_ROOT}/torch/nativert/executor/AOTInductorDelegateExecutor.cpp
${TORCH_ROOT}/torch/nativert/kernels/ETCallDelegateKernel.cpp
${TORCH_ROOT}/torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp
)
if(USE_CUDA)
list(APPEND NATIVERT_TEST_SRCS ${TORCH_ROOT}/torch/nativert/executor/triton/CudaTritonKernelManager.cpp)
list(APPEND NATIVERT_TEST_SRCS ${TORCH_ROOT}/torch/nativert/executor/AOTInductorModelContainerCudaShim.cpp)
endif()
add_executable(test_nativert

@@ -0,0 +1,16 @@
#include <gtest/gtest.h>
#include <torch/nativert/executor/AOTInductorDelegateExecutor.h>
using namespace ::testing;
using namespace torch::nativert;
TEST(AOTIModelContainerRegistrationTests, TestRegister) {
EXPECT_TRUE(AOTIModelContainerRunnerRegistry()->Has(at::kCPU));
#ifdef USE_CUDA
EXPECT_TRUE(AOTIModelContainerRunnerRegistry()->Has(at::kCUDA));
#else
EXPECT_FALSE(AOTIModelContainerRunnerRegistry()->Has(at::kCUDA));
#endif // USE_CUDA
}

@@ -6,9 +6,19 @@ import pathlib
import tempfile
import unittest
from parameterized import parameterized
import torch
import torch._dynamo as torchdynamo
from torch._C._nativert import PyModelRunner
from torch._dynamo.test_case import TestCase
from torch._subclasses.fake_tensor import FakeTensor
from torch.nativert.backends._lower_utils import (
lower_exported_program,
package_nativert_with_aoti_delegate,
)
from torch.testing._internal.common_utils import IS_WINDOWS
from torch.testing._internal.inductor_utils import HAS_GPU
from torch.utils import _pytree as pytree
@@ -185,6 +195,153 @@ def make_dynamic_cls(cls, strict=False):
test_class.__module__ = __name__
@unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case")
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestNativeRT(TestCase):
@staticmethod
def get_module():
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(4, 4)
self.relu = torch.nn.ReLU()
def forward(self, x):
return self.relu(self.linear(x))
return M()
@staticmethod
def get_module_multi_output():
class MMultiOutput(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(4, 4)
self.relu = torch.nn.ReLU()
def forward(self, x):
return (self.relu(self.linear(x)), x)
return MMultiOutput()
@staticmethod
def get_model_pytree():
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear1 = torch.nn.Linear(4, 4)
self.linear2 = torch.nn.Linear(4, 4)
def forward(self, x):
x1, (x2, x3) = x
y1 = self.linear1(x1)
y2 = self.linear2(x2)
y3 = self.linear2(x3)
return (y1, (y2, y3))
return M()
parameters = []
for device in ["cpu", "cuda"]:
if device == "cuda" and not HAS_GPU:
continue
for module, sample_inputs in [
(get_module.__func__().to(device), (torch.randn(4, 4).to(device),)),
(
get_module_multi_output.__func__().to(device),
(torch.randn(4, 4).to(device),),
),
(
get_model_pytree.__func__().to(device),
(
(
torch.randn(4, 4).to(device),
(
torch.randn(4, 4).to(device),
torch.randn(4, 4).to(device),
),
),
),
),
]:
parameters.append(
(
device,
module,
sample_inputs,
)
)
@parameterized.expand(parameters)
def test_aoti(self, device, m, sample_inputs):
MODEL_NAME = "model"
BACKEND_ID = "aoti"
# get the original EP
original_ep = torch.export.export(m, sample_inputs)
aoti_delegate_ep, aoti_files = lower_exported_program(
original_ep, MODEL_NAME, BACKEND_ID
)
# package everything needed for the NativeRT to execute the AOTI delegate
with tempfile.NamedTemporaryFile(suffix=".pt2", delete=False) as f:
package_nativert_with_aoti_delegate(
f,
MODEL_NAME,
BACKEND_ID,
original_ep,
aoti_delegate_ep,
aoti_files,
)
filename = f.name
try:
ep_args, ep_kwargs = aoti_delegate_ep.example_inputs
ep_args_copied, ep_kwargs_copied = (
copy.deepcopy(ep_args),
copy.deepcopy(ep_kwargs),
)
torch.manual_seed(0)
try:
flat_expected = pytree.tree_leaves(
aoti_delegate_ep.module()(*ep_args_copied, **ep_kwargs_copied)
)
except Exception as e:
raise unittest.case.SkipTest(str(e)) from e
model_runner = PyModelRunner(filename, f"{MODEL_NAME}-{BACKEND_ID}")
torch.manual_seed(0)
if _is_supported_types((ep_args, ep_kwargs)):
results = model_runner.run(*ep_args, **ep_kwargs)
else:
results = model_runner.run_with_flat_inputs_and_outputs(
*pytree.tree_leaves((ep_args, ep_kwargs))
)
flat_results = pytree.tree_leaves(results)
assert len(flat_results) == len(flat_expected)
for result, expected in zip(flat_results, flat_expected):
assert type(result) == type(expected)
if isinstance(result, torch.Tensor) and isinstance(
expected, torch.Tensor
):
assert result.shape == expected.shape
assert result.dtype == expected.dtype
assert result.device == expected.device
torch.testing.assert_close(result, expected, equal_nan=True)
else:
assert result == expected
except RuntimeError as e:
# Users need to register the pytree node type on the C++ side,
# which cannot be tested from a Python unittest.
if "Unknown pytree node type" in str(e):
pass
else:
raise e
finally:
pathlib.Path(filename).unlink(missing_ok=True)
tests = [
test_export.TestExport,
]

@@ -85,6 +85,7 @@ ModelRunner::ModelRunner(
weights->validateAllWeightsLoaded();
torch::nativert::ExecutorConfig config;
config.modelName = modelName;
executor_ = std::make_unique<Executor>(
config, graph_, std::move(weights), pytorchStreamReader);

@@ -1,4 +0,0 @@
-from .lowered_aoti_module import LoweredBackendModule
-__all__ = ["LoweredBackendModule"]

@@ -6,7 +6,7 @@ from torch.export import ExportedProgram
from torch.export.pt2_archive._package import AOTI_FILES, package_pt2
from torch.types import FileLike
-from .lowered_aoti_module import LoweredBackendModule
+from ._lowered_aoti_module import LoweredBackendModule
def get_new_ep_with_flat_inputs_outputs(ep: ExportedProgram) -> ExportedProgram:

@@ -0,0 +1,168 @@
#include <torch/nativert/executor/AOTInductorDelegateExecutor.h>
#include <ATen/core/Tensor.h>
#include <ATen/record_function.h>
#include <c10/util/Logging.h>
#include <torch/csrc/export/pt2_archive_constants.h>
#include <torch/csrc/utils/generated_serialization_types.h>
#include <torch/nativert/executor/Weights.h>
namespace torch::nativert {
#ifndef NATIVERT_MSVC_TEST
C10_DEFINE_TYPED_REGISTRY(
AOTIModelContainerRunnerRegistry,
c10::DeviceType,
torch::inductor::AOTIModelContainerRunner,
std::unique_ptr,
const std::string&,
size_t,
const std::string&,
const std::string&,
const bool)
#endif // NATIVERT_MSVC_TEST
namespace {
template <typename T>
std::optional<at::ScalarType> parse_precision(
const std::optional<T>& precision) {
if (precision) {
return static_cast<at::ScalarType>(*precision);
}
return std::nullopt;
}
c10::Device infer_target_device(const Node& node) {
std::vector<c10::Device> devices;
const auto& tensorValuesMeta = node.owningGraph()->tensorValuesMeta();
for (const auto* output : node.outputs()) {
if (auto it = tensorValuesMeta.find(std::string{output->name()});
it != tensorValuesMeta.end()) {
devices.emplace_back(it->second.device());
}
}
TORCH_CHECK(!devices.empty(), "AOTI node should have at least one output");
for (const auto i : c10::irange(1, devices.size())) {
if (!torch::nativert::isSameDevice(devices[0], devices[i])) {
LOG(WARNING) << "Node " << node
<< " has outputs on multiple devices: " << devices[0]
<< " and " << devices[i];
}
}
return devices[0];
}
std::unique_ptr<torch::inductor::AOTIModelContainerRunner>
create_aoti_model_container_runner_cpu(
const std::string& model_so_path,
size_t num_models,
const std::string& device_str,
const std::string& cubin_dir,
const bool run_single_threaded) {
return std::make_unique<torch::inductor::AOTIModelContainerRunnerCpu>(
model_so_path,
num_models,
/* run_single_threaded= */ run_single_threaded);
}
} // namespace
C10_REGISTER_TYPED_CREATOR(
AOTIModelContainerRunnerRegistry,
at::kCPU,
create_aoti_model_container_runner_cpu)
AOTIDelegateExecutor::AOTIDelegateExecutor(
const Node& node,
const std::shared_ptr<Weights>& weights,
const ExecutorConfig& executorConfig,
caffe2::serialize::PyTorchStreamReader* packageReader,
const MakeProxyExecutorFn& makeProxyExecutorFunc)
: ETDelegateExecutor(torch::_export::archive_spec::AOTINDUCTOR_DIR, node) {
TORCH_CHECK(
packageReader, "Package reader cannot be null for lowered modules");
auto path = get_delegate_dir() + "/";
LOG(INFO) << "Loading aotinductor model from archive path: " << path;
std::optional<std::string> model_name = std::nullopt;
for (const auto& record : packageReader->getAllRecords()) {
if (c10::starts_with(record, path) && c10::ends_with(record, ".so")) {
model_name = record.substr(record.find_last_of("/\\") + 1);
break;
}
}
TORCH_CHECK(model_name.has_value(), "missing model .so in archive: ", path);
path.pop_back(); // remove trailing slash
std::string tmp_dir = extractToTemporaryFolder(*packageReader, path);
LOG(INFO) << "Extracted aot_inductor model to: " << tmp_dir;
std::string model_path = tmp_dir + "/" + *model_name;
LOG(INFO) << "Loading aotinductor model from model path: " << model_path;
auto device = infer_target_device(node);
LOG(INFO) << "Creating AOTI model container runner with device "
<< device.str();
aoti_model_container_runner_ = AOTIModelContainerRunnerRegistry()->Create(
device.type(),
model_path,
/* num_models= */ executorConfig.maxNumConcurrentThreads,
device.str(),
/*cubin_dir=*/tmp_dir,
/*run_single_threaded=*/false);
for (const auto& [name, original_fqn] :
aoti_model_container_runner_->getConstantNamesToOriginalFQNs()) {
if (weights->contains(original_fqn)) {
weight_names_map_[original_fqn] = name;
} else {
LOG(WARNING)
<< "AOTI's Constant " << original_fqn
<< " is not found in weights, it's likely a constant created by AOTI constant folding. "
<< "Valid weight FQNs are " << weights->toString();
}
}
// AOTI's DelegateExecutor doesn't need to call processWeights or
// commitWeights here because it's invoked from Executor's ctor already.
}
void AOTIDelegateExecutor::initWeights(std::shared_ptr<Weights> weights) {
// Do nothing for AOTI, as AOTI's .so already contains the weights.
LOG(INFO)
<< "Skipping initWeights for AOTI to use original weights from .so file.";
}
void AOTIDelegateExecutor::processWeights(std::shared_ptr<Weights> weights) {
LOG(INFO) << "AOTIDelegateExecutor processing weights";
std::unordered_map<std::string, at::Tensor*> new_weights;
for (const auto& [original_fqn, name] : weight_names_map_) {
new_weights.emplace(name, &weights->at(original_fqn));
}
aoti_model_container_runner_->update_inactive_constant_buffer(new_weights);
aoti_model_container_runner_->run_const_fold(/*use_inactive=*/true);
}
void AOTIDelegateExecutor::commitWeights() {
LOG(INFO) << "AOTIDelegateExecutor committing weights";
aoti_model_container_runner_->swap_constant_buffer();
}
std::vector<at::Tensor> AOTIDelegateExecutor::run(
std::vector<at::Tensor>& inputs) {
RECORD_USER_SCOPE("sigmoid::AOTIDelegateExecutor::run");
std::vector<at::Tensor> outputs = aoti_model_container_runner_->run(inputs);
return outputs;
}
} // namespace torch::nativert
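
The weight-update path above is a two-phase swap over AOTI's double-buffered constant store: processWeights() stages the new tensors into the inactive buffer and re-runs constant folding there, and commitWeights() flips the buffers. A sketch of the sequence the Executor drives (the helper function below is illustrative):

    // Illustrative sketch: two-phase weight swap on an AOTIDelegateExecutor.
    void update_weights(
        torch::nativert::AOTIDelegateExecutor& executor,
        std::shared_ptr<torch::nativert::Weights> weights) {
      executor.processWeights(weights); // stage into inactive buffer + const-fold
      executor.commitWeights(); // swap_constant_buffer(): activate staged weights
    }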

@@ -0,0 +1,49 @@
#pragma once
#include <torch/csrc/inductor/aoti_runner/model_container_runner.h>
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
#include <torch/nativert/executor/ETDelegateExecutor.h>
#include <torch/nativert/executor/ExecutorConfig.h>
#ifdef USE_CUDA
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
#endif
namespace torch::nativert {
class AOTIDelegateExecutor : public ETDelegateExecutor {
public:
explicit AOTIDelegateExecutor(
const Node& node,
const std::shared_ptr<Weights>& weights,
const ExecutorConfig& executorConfig,
caffe2::serialize::PyTorchStreamReader* packageReader,
const MakeProxyExecutorFn& makeProxyExecutorFunc);
~AOTIDelegateExecutor() override = default;
void processWeights(std::shared_ptr<Weights> weights) override;
void initWeights(std::shared_ptr<Weights> weights) override;
void commitWeights() override;
std::vector<at::Tensor> run(std::vector<at::Tensor>& inputs) override;
private:
std::unique_ptr<torch::inductor::AOTIModelContainerRunner>
aoti_model_container_runner_;
// key is weight's original fqn, value is weight's name in AOTI
std::unordered_map<std::string, std::string> weight_names_map_;
};
C10_DECLARE_TYPED_REGISTRY(
AOTIModelContainerRunnerRegistry,
c10::DeviceType,
torch::inductor::AOTIModelContainerRunner,
std::unique_ptr,
const std::string&,
size_t,
const std::string&,
const std::string&,
const bool);
} // namespace torch::nativert
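
Runner construction is dispatched through this typed registry keyed on c10::DeviceType, so a new device backend can plug in without touching the executor. A hypothetical registration mirroring the CPU/CUDA creators in this diff (MyRunner and the at::kXPU key are illustrative; this would sit inside namespace torch::nativert):

    // Hypothetical: register an AOTI runner creator for another device type.
    std::unique_ptr<torch::inductor::AOTIModelContainerRunner>
    create_my_runner(
        const std::string& model_so_path,
        size_t num_models,
        const std::string& device_str,
        const std::string& cubin_dir,
        const bool run_single_threaded) {
      // MyRunner would be a user-provided AOTIModelContainerRunner subclass.
      return std::make_unique<MyRunner>(
          model_so_path, num_models, device_str, cubin_dir, run_single_threaded);
    }
    C10_REGISTER_TYPED_CREATOR(
        AOTIModelContainerRunnerRegistry, at::kXPU, create_my_runner)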

@@ -0,0 +1,24 @@
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
#include <torch/nativert/executor/AOTInductorDelegateExecutor.h>
namespace torch::nativert {
namespace {
std::unique_ptr<torch::inductor::AOTIModelContainerRunner>
create_aoti_model_container_runner_cuda(
const std::string& model_so_path,
size_t num_models,
const std::string& device_str,
const std::string& cubin_dir,
const bool run_single_threaded) {
return std::make_unique<torch::inductor::AOTIModelContainerRunnerCuda>(
model_so_path, num_models, device_str, cubin_dir, run_single_threaded);
}
} // namespace
C10_REGISTER_TYPED_CREATOR(
AOTIModelContainerRunnerRegistry,
at::kCUDA,
create_aoti_model_container_runner_cuda)
} // namespace torch::nativert

@@ -0,0 +1,34 @@
#pragma once
#include <torch/nativert/executor/DelegateExecutor.h>
#include <torch/nativert/executor/ExecutorConfig.h>
namespace torch::nativert {
class ETDelegateExecutor : public DelegateExecutor {
public:
explicit ETDelegateExecutor(
const std::string_view& dir_prefix,
const Node& node)
: delegate_dir_([&]() {
const std::string* path =
std::get_if<std::string>(&node.attributes()[0].value);
TORCH_CHECK(
path != nullptr,
"et hop's first attribute should correspond to it's path");
return std::string(dir_prefix) + *path;
}()) {
VLOG(1) << "ETDelegateExecutor: " << delegate_dir_;
}
~ETDelegateExecutor() override = default;
const std::string& get_delegate_dir() {
return delegate_dir_;
}
private:
std::string delegate_dir_;
};
} // namespace torch::nativert
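
ETDelegateExecutor only resolves the delegate's archive directory: the node's first attribute carries the archive-relative path, which is prefixed with a backend-specific directory (AOTINDUCTOR_DIR in the AOTI executor above). A hypothetical minimal subclass, assuming the overrides seen on AOTIDelegateExecutor are the required DelegateExecutor hooks:

    // Hypothetical delegate built on ETDelegateExecutor; prefix is illustrative.
    class MyDelegateExecutor : public ETDelegateExecutor {
     public:
      explicit MyDelegateExecutor(const Node& node)
          : ETDelegateExecutor("my_backend/", node) {}
      void initWeights(std::shared_ptr<Weights> /*weights*/) override {}
      void processWeights(std::shared_ptr<Weights> /*weights*/) override {}
      void commitWeights() override {}
      std::vector<at::Tensor> run(std::vector<at::Tensor>& inputs) override {
        // Load artifacts from get_delegate_dir(), execute, return outputs.
        return inputs;
      }
    };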

@@ -0,0 +1,43 @@
#include <torch/nativert/kernels/ETCallDelegateKernel.h>
#include <torch/nativert/executor/ETDelegateExecutor.h>
namespace torch::nativert {
ETCallDelegateKernel::ETCallDelegateKernel(
const Node* node,
ETDelegateExecutor& delegateExecutor)
: OpKernel(node), delegateExecutor_(delegateExecutor) {
for (const auto& input : node_->inputs()) {
TORCH_CHECK(input.value->type() == Type::Kind::Tensor);
}
for (const auto* output : node_->outputs()) {
TORCH_CHECK(output->type() == Type::Kind::Tensor);
}
}
void ETCallDelegateKernel::computeInternal(
ExecutionFrame& executionFrame) const {
std::vector<at::Tensor> inputs;
inputs.reserve(numInputs());
for (const auto& input : node_->inputs()) {
inputs.emplace_back(executionFrame.getTensor(input.value->id()));
}
auto outputs = delegateExecutor_.run(inputs);
const auto& node_outputs = node_->outputs();
TORCH_CHECK(outputs.size() == node_outputs.size());
size_t i = 0;
for (auto begin = std::make_move_iterator(outputs.begin()),
end = std::make_move_iterator(outputs.end());
begin != end;
++begin) {
executionFrame.setIValue(node_outputs[i]->id(), *begin);
i++;
}
}
} // namespace torch::nativert

@@ -0,0 +1,22 @@
#pragma once
#include <torch/nativert/executor/ExecutionFrame.h>
#include <torch/nativert/executor/OpKernel.h>
namespace torch::nativert {
class ETDelegateExecutor;
class ETCallDelegateKernel : public OpKernel {
public:
explicit ETCallDelegateKernel(
const Node* node,
ETDelegateExecutor& delegateExecutor);
void computeInternal(ExecutionFrame& executionFrame) const override final;
private:
ETDelegateExecutor& delegateExecutor_;
};
} // namespace torch::nativert

@@ -12,6 +12,10 @@
#include <torch/nativert/kernels/KernelFactory.h>
#include <torch/nativert/kernels/KernelRegistry.h>
#include <torch/csrc/inductor/aoti_torch/oss_proxy_executor.h>
#include <torch/nativert/executor/AOTInductorDelegateExecutor.h>
#include <torch/nativert/kernels/ETCallDelegateKernel.h>
namespace torch::nativert {
namespace {
@@ -31,6 +35,14 @@ std::string maybeRevisedStaticDispatchTarget(const Node& node) {
}
return std::string(node.target());
}
std::unique_ptr<torch::aot_inductor::ProxyExecutor> make_proxy_executor(
const std::string& filename,
bool is_cpu,
std::optional<std::unordered_map<std::string, c10::IValue>> custom_objs) {
return std::make_unique<torch::aot_inductor::OSSProxyExecutor>(
filename, is_cpu, std::move(custom_objs));
}
} // namespace
void register_kernel_handlers() {
@@ -62,6 +74,35 @@ void register_kernel_handlers() {
->Create(maybeRevisedStaticDispatchTarget(node), &node),
nullptr};
}));
KernelFactory::registerHandler(
"et_delegate",
KernelFactoryHandler(
[](const Node& node,
const torch::nativert::ExecutorConfig& /* executorConfig */) {
return c10::starts_with(
node.target(),
"torch.ops.higher_order.executorch_call_delegate");
},
[](const Node& node,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::shared_ptr<Weights> weights,
const torch::nativert::ExecutorConfig& executorConfig,
caffe2::serialize::PyTorchStreamReader* packageReader)
-> std::pair<
KernelFactoryHandler::OpKernelPtr,
KernelFactoryHandler::DelegateExecutorPtr> {
auto delegateExecutor = std::make_unique<AOTIDelegateExecutor>(
node,
weights,
executorConfig,
packageReader,
make_proxy_executor);
return {
std::make_unique<ETCallDelegateKernel>(
&node, *delegateExecutor),
std::move(delegateExecutor)};
}));
});
}