mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[nativert] aoti (#162353)
Summary: att Test Plan: ci Rollback Plan: Differential Revision: D81731425 Pull Request resolved: https://github.com/pytorch/pytorch/pull/162353 Approved by: https://github.com/yiming0416
This commit is contained in:
@ -638,10 +638,13 @@ libtorch_nativert_sources = [
|
||||
"torch/nativert/kernels/KernelHandlerRegistry.cpp",
|
||||
"torch/nativert/kernels/TritonKernel.cpp",
|
||||
"torch/nativert/executor/triton/CpuTritonKernelManager.cpp",
|
||||
"torch/nativert/executor/AOTInductorDelegateExecutor.cpp",
|
||||
"torch/nativert/kernels/ETCallDelegateKernel.cpp",
|
||||
]
|
||||
|
||||
libtorch_nativert_cuda_sources = [
|
||||
"torch/nativert/executor/triton/CudaTritonKernelManager.cpp",
|
||||
"torch/nativert/executor/AOTInductorModelContainerCudaShim.cpp",
|
||||
]
|
||||
|
||||
torch_mobile_tracer_sources = [
|
||||
|
17
docs/source/nativert.rst
Normal file
17
docs/source/nativert.rst
Normal file
@ -0,0 +1,17 @@
|
||||
torch.nativert
|
||||
==============
|
||||
|
||||
.. automodule:: torch.nativert
|
||||
.. currentmodule:: torch.nativert
|
||||
|
||||
.. py:module:: torch.nativert
|
||||
:noindex:
|
||||
|
||||
torch.nativert.backends
|
||||
-----------------------
|
||||
|
||||
.. automodule:: torch.nativert.backends
|
||||
.. currentmodule:: torch.nativert.backends
|
||||
|
||||
.. py:module:: torch.nativert.backends
|
||||
:noindex:
|
@ -56,6 +56,7 @@ torch.monitor <monitor>
|
||||
torch.signal <signal>
|
||||
torch.special <special>
|
||||
torch.overrides
|
||||
torch.nativert <nativert>
|
||||
torch.package <package>
|
||||
profiler
|
||||
nn.init
|
||||
|
@ -43,10 +43,14 @@ set(NATIVERT_TEST_SRCS
|
||||
${TORCH_ROOT}/torch/nativert/executor/triton/CpuTritonKernelManager.cpp
|
||||
${TORCH_ROOT}/torch/nativert/kernels/TritonKernel.cpp
|
||||
${TORCH_ROOT}/torch/nativert/executor/DelegateExecutor.cpp
|
||||
${TORCH_ROOT}/torch/nativert/executor/AOTInductorDelegateExecutor.cpp
|
||||
${TORCH_ROOT}/torch/nativert/kernels/ETCallDelegateKernel.cpp
|
||||
${TORCH_ROOT}/torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp
|
||||
)
|
||||
|
||||
if(USE_CUDA)
|
||||
list(APPEND NATIVERT_TEST_SRCS ${TORCH_ROOT}/torch/nativert/executor/triton/CudaTritonKernelManager.cpp)
|
||||
list(APPEND NATIVERT_TEST_SRCS ${TORCH_ROOT}/torch/nativert/executor/AOTInductorModelContainerCudaShim.cpp)
|
||||
endif()
|
||||
|
||||
add_executable(test_nativert
|
||||
|
16
test/cpp/nativert/test_aoti_model_container_registration.cpp
Normal file
16
test/cpp/nativert/test_aoti_model_container_registration.cpp
Normal file
@ -0,0 +1,16 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <torch/nativert/executor/AOTInductorDelegateExecutor.h>
|
||||
|
||||
using namespace ::testing;
|
||||
using namespace torch::nativert;
|
||||
|
||||
TEST(AOTIModelContainerRegistrationTests, TestRegister) {
|
||||
EXPECT_TRUE(AOTIModelContainerRunnerRegistry()->Has(at::kCPU));
|
||||
|
||||
#ifdef USE_CUDA
|
||||
EXPECT_TRUE(AOTIModelContainerRunnerRegistry()->Has(at::kCUDA));
|
||||
#else
|
||||
EXPECT_FALSE(AOTIModelContainerRunnerRegistry()->Has(at::kCUDA));
|
||||
#endif // USE_CUDA
|
||||
}
|
@ -6,9 +6,19 @@ import pathlib
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from parameterized import parameterized
|
||||
|
||||
import torch
|
||||
import torch._dynamo as torchdynamo
|
||||
from torch._C._nativert import PyModelRunner
|
||||
from torch._dynamo.test_case import TestCase
|
||||
from torch._subclasses.fake_tensor import FakeTensor
|
||||
from torch.nativert.backends._lower_utils import (
|
||||
lower_exported_program,
|
||||
package_nativert_with_aoti_delegate,
|
||||
)
|
||||
from torch.testing._internal.common_utils import IS_WINDOWS
|
||||
from torch.testing._internal.inductor_utils import HAS_GPU
|
||||
from torch.utils import _pytree as pytree
|
||||
|
||||
|
||||
@ -185,6 +195,153 @@ def make_dynamic_cls(cls, strict=False):
|
||||
test_class.__module__ = __name__
|
||||
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case")
|
||||
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
|
||||
class TestNativeRT(TestCase):
|
||||
@staticmethod
|
||||
def get_module():
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.linear = torch.nn.Linear(4, 4)
|
||||
self.relu = torch.nn.ReLU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.relu(self.linear(x))
|
||||
|
||||
return M()
|
||||
|
||||
@staticmethod
|
||||
def get_module_multi_output():
|
||||
class MMultiOutput(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.linear = torch.nn.Linear(4, 4)
|
||||
self.relu = torch.nn.ReLU()
|
||||
|
||||
def forward(self, x):
|
||||
return (self.relu(self.linear(x)), x)
|
||||
|
||||
return MMultiOutput()
|
||||
|
||||
@staticmethod
|
||||
def get_model_pytree():
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.linear1 = torch.nn.Linear(4, 4)
|
||||
self.linear2 = torch.nn.Linear(4, 4)
|
||||
|
||||
def forward(self, x):
|
||||
x1, (x2, x3) = x
|
||||
y1 = self.linear1(x1)
|
||||
y2 = self.linear2(x2)
|
||||
y3 = self.linear2(x3)
|
||||
return (y1, (y2, y3))
|
||||
|
||||
return M()
|
||||
|
||||
parameters = []
|
||||
for device in ["cpu", "cuda"]:
|
||||
if device == "cuda" and not HAS_GPU:
|
||||
continue
|
||||
for module, sample_inputs in [
|
||||
(get_module.__func__().to(device), (torch.randn(4, 4).to(device),)),
|
||||
(
|
||||
get_module_multi_output.__func__().to(device),
|
||||
(torch.randn(4, 4).to(device),),
|
||||
),
|
||||
(
|
||||
get_model_pytree.__func__().to(device),
|
||||
(
|
||||
(
|
||||
torch.randn(4, 4).to(device),
|
||||
(
|
||||
torch.randn(4, 4).to(device),
|
||||
torch.randn(4, 4).to(device),
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
]:
|
||||
parameters.append(
|
||||
(
|
||||
device,
|
||||
module,
|
||||
sample_inputs,
|
||||
)
|
||||
)
|
||||
|
||||
@parameterized.expand(parameters)
|
||||
def test_aoti(self, device, m, sample_inputs):
|
||||
MODEL_NAME = "model"
|
||||
BACKEND_ID = "aoti"
|
||||
|
||||
# get the original EP
|
||||
original_ep = torch.export.export(m, sample_inputs)
|
||||
|
||||
aoti_delegate_ep, aoti_files = lower_exported_program(
|
||||
original_ep, MODEL_NAME, BACKEND_ID
|
||||
)
|
||||
|
||||
# package everything needed for the NativeRT to execute the AOTI delegate
|
||||
with tempfile.NamedTemporaryFile(suffix=".pt2", delete=False) as f:
|
||||
package_nativert_with_aoti_delegate(
|
||||
f,
|
||||
MODEL_NAME,
|
||||
BACKEND_ID,
|
||||
original_ep,
|
||||
aoti_delegate_ep,
|
||||
aoti_files,
|
||||
)
|
||||
filename = f.name
|
||||
|
||||
try:
|
||||
ep_args, ep_kwargs = aoti_delegate_ep.example_inputs
|
||||
ep_args_copied, ep_kwargs_copied = (
|
||||
copy.deepcopy(ep_args),
|
||||
copy.deepcopy(ep_kwargs),
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
try:
|
||||
flat_expected = pytree.tree_leaves(
|
||||
aoti_delegate_ep.module()(*ep_args_copied, **ep_kwargs_copied)
|
||||
)
|
||||
except Exception as e:
|
||||
raise unittest.case.SkipTest(str(e)) from e
|
||||
|
||||
model_runner = PyModelRunner(filename, f"{MODEL_NAME}-{BACKEND_ID}")
|
||||
torch.manual_seed(0)
|
||||
if _is_supported_types((ep_args, ep_kwargs)):
|
||||
results = model_runner.run(*ep_args, **ep_kwargs)
|
||||
else:
|
||||
results = model_runner.run_with_flat_inputs_and_outputs(
|
||||
*pytree.tree_leaves((ep_args, ep_kwargs))
|
||||
)
|
||||
flat_results = pytree.tree_leaves(results)
|
||||
assert len(flat_results) == len(flat_expected)
|
||||
for result, expected in zip(flat_results, flat_expected):
|
||||
assert type(result) == type(expected)
|
||||
if isinstance(result, torch.Tensor) and isinstance(
|
||||
expected, torch.Tensor
|
||||
):
|
||||
assert result.shape == expected.shape
|
||||
assert result.dtype == expected.dtype
|
||||
assert result.device == expected.device
|
||||
torch.testing.assert_close(result, expected, equal_nan=True)
|
||||
else:
|
||||
assert result == expected
|
||||
except RuntimeError as e:
|
||||
# User need to register pytree type on the cpp side, which
|
||||
# cannot be tested in python unittest.
|
||||
if "Unknown pytree node type" in str(e):
|
||||
pass
|
||||
else:
|
||||
raise e
|
||||
finally:
|
||||
pathlib.Path(filename).unlink(missing_ok=True)
|
||||
|
||||
|
||||
tests = [
|
||||
test_export.TestExport,
|
||||
]
|
||||
|
@ -85,6 +85,7 @@ ModelRunner::ModelRunner(
|
||||
weights->validateAllWeightsLoaded();
|
||||
|
||||
torch::nativert::ExecutorConfig config;
|
||||
config.modelName = modelName;
|
||||
|
||||
executor_ = std::make_unique<Executor>(
|
||||
config, graph_, std::move(weights), pytorchStreamReader);
|
||||
|
0
torch/nativert/__init__.py
Normal file
0
torch/nativert/__init__.py
Normal file
@ -1,4 +0,0 @@
|
||||
from .lowered_aoti_module import LoweredBackendModule
|
||||
|
||||
|
||||
__all__ = ["LoweredBackendModule"]
|
||||
|
@ -6,7 +6,7 @@ from torch.export import ExportedProgram
|
||||
from torch.export.pt2_archive._package import AOTI_FILES, package_pt2
|
||||
from torch.types import FileLike
|
||||
|
||||
from .lowered_aoti_module import LoweredBackendModule
|
||||
from ._lowered_aoti_module import LoweredBackendModule
|
||||
|
||||
|
||||
def get_new_ep_with_flat_inputs_outputs(ep: ExportedProgram) -> ExportedProgram:
|
168
torch/nativert/executor/AOTInductorDelegateExecutor.cpp
Normal file
168
torch/nativert/executor/AOTInductorDelegateExecutor.cpp
Normal file
@ -0,0 +1,168 @@
|
||||
#include <torch/nativert/executor/AOTInductorDelegateExecutor.h>
|
||||
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/record_function.h>
|
||||
#include <c10/util/Logging.h>
|
||||
|
||||
#include <torch/csrc/export/pt2_archive_constants.h>
|
||||
#include <torch/csrc/utils/generated_serialization_types.h>
|
||||
#include <torch/nativert/executor/Weights.h>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
#ifndef NATIVERT_MSVC_TEST
|
||||
C10_DEFINE_TYPED_REGISTRY(
|
||||
AOTIModelContainerRunnerRegistry,
|
||||
c10::DeviceType,
|
||||
torch::inductor::AOTIModelContainerRunner,
|
||||
std::unique_ptr,
|
||||
const std::string&,
|
||||
size_t,
|
||||
const std::string&,
|
||||
const std::string&,
|
||||
const bool)
|
||||
#endif // NATIVERT_MSVC_TEST
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
std::optional<at::ScalarType> parse_precision(
|
||||
const std::optional<T>& precision) {
|
||||
if (precision) {
|
||||
return static_cast<at::ScalarType>(*precision);
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
c10::Device infer_target_device(const Node& node) {
|
||||
std::vector<c10::Device> devices;
|
||||
|
||||
const auto& tensorValuesMeta = node.owningGraph()->tensorValuesMeta();
|
||||
for (const auto* output : node.outputs()) {
|
||||
if (auto it = tensorValuesMeta.find(std::string{output->name()});
|
||||
it != tensorValuesMeta.end()) {
|
||||
devices.emplace_back(it->second.device());
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_CHECK(!devices.empty(), "AOTI node should have at least one output");
|
||||
for (const auto i : c10::irange(1, devices.size())) {
|
||||
if (!torch::nativert::isSameDevice(devices[0], devices[i])) {
|
||||
LOG(WARNING) << "Node " << node
|
||||
<< " has outputs on multiple devices: " << devices[0]
|
||||
<< " and " << devices[i];
|
||||
}
|
||||
}
|
||||
|
||||
return devices[0];
|
||||
}
|
||||
|
||||
std::unique_ptr<torch::inductor::AOTIModelContainerRunner>
|
||||
create_aoti_model_container_runner_cpu(
|
||||
const std::string& model_so_path,
|
||||
size_t num_models,
|
||||
const std::string& device_str,
|
||||
const std::string& cubin_dir,
|
||||
const bool run_single_threaded) {
|
||||
return std::make_unique<torch::inductor::AOTIModelContainerRunnerCpu>(
|
||||
model_so_path,
|
||||
num_models,
|
||||
/* run_single_threaded= */ run_single_threaded);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
C10_REGISTER_TYPED_CREATOR(
|
||||
AOTIModelContainerRunnerRegistry,
|
||||
at::kCPU,
|
||||
create_aoti_model_container_runner_cpu)
|
||||
|
||||
AOTIDelegateExecutor::AOTIDelegateExecutor(
|
||||
const Node& node,
|
||||
const std::shared_ptr<Weights>& weights,
|
||||
const ExecutorConfig& executorConfig,
|
||||
caffe2::serialize::PyTorchStreamReader* packageReader,
|
||||
const MakeProxyExecutorFn& makeProxyExecutorFunc)
|
||||
: ETDelegateExecutor(torch::_export::archive_spec::AOTINDUCTOR_DIR, node) {
|
||||
TORCH_CHECK(
|
||||
packageReader, "Package reader cannot be null for lowered modules");
|
||||
|
||||
auto path = get_delegate_dir() + "/";
|
||||
|
||||
LOG(INFO) << "Loading aotinductor model from archive path: " << path;
|
||||
|
||||
std::optional<std::string> model_name = std::nullopt;
|
||||
for (const auto& record : packageReader->getAllRecords()) {
|
||||
if (c10::starts_with(record, path) && c10::ends_with(record, ".so")) {
|
||||
model_name = record.substr(record.find_last_of("/\\") + 1);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_CHECK(model_name.has_value(), "missing model .so in archive: ", path);
|
||||
path.pop_back(); // remove trailing slash
|
||||
|
||||
std::string tmp_dir = extractToTemporaryFolder(*packageReader, path);
|
||||
LOG(INFO) << "Extracted aot_inductor model to: " << tmp_dir;
|
||||
|
||||
std::string model_path = tmp_dir + "/" + *model_name;
|
||||
|
||||
LOG(INFO) << "Loading aotinductor model from model path: " << model_path;
|
||||
|
||||
auto device = infer_target_device(node);
|
||||
LOG(INFO) << "Creating AOTI model container runner with device "
|
||||
<< device.str();
|
||||
|
||||
aoti_model_container_runner_ = AOTIModelContainerRunnerRegistry()->Create(
|
||||
device.type(),
|
||||
model_path,
|
||||
/* num_models= */ executorConfig.maxNumConcurrentThreads,
|
||||
device.str(),
|
||||
/*cubin_dir=*/tmp_dir,
|
||||
/*run_single_threaded=*/false);
|
||||
|
||||
for (const auto& [name, original_fqn] :
|
||||
aoti_model_container_runner_->getConstantNamesToOriginalFQNs()) {
|
||||
if (weights->contains(original_fqn)) {
|
||||
weight_names_map_[original_fqn] = name;
|
||||
} else {
|
||||
LOG(WARNING)
|
||||
<< "AOTI's Constant " << original_fqn
|
||||
<< " is not found in weights, it's likely a constant created by AOTI constant folding. "
|
||||
<< "Valid weight FQNs are " << weights->toString();
|
||||
}
|
||||
}
|
||||
|
||||
// AOTI's DelegateExecutor doesn't need to call processWeights or
|
||||
// commitWeights here because it's invoked from Executor's ctor already.
|
||||
}
|
||||
|
||||
void AOTIDelegateExecutor::initWeights(std::shared_ptr<Weights> weights) {
|
||||
// Do nothing for AOTI, as AOTI's .so already contains the weights.
|
||||
LOG(INFO)
|
||||
<< "Skipping initWeights for AOTI to use original weights from .so file.";
|
||||
}
|
||||
|
||||
void AOTIDelegateExecutor::processWeights(std::shared_ptr<Weights> weights) {
|
||||
LOG(INFO) << "AOTIDelegateExecutor processing weights";
|
||||
std::unordered_map<std::string, at::Tensor*> new_weights;
|
||||
for (const auto& [original_fqn, name] : weight_names_map_) {
|
||||
new_weights.emplace(name, &weights->at(original_fqn));
|
||||
}
|
||||
|
||||
aoti_model_container_runner_->update_inactive_constant_buffer(new_weights);
|
||||
aoti_model_container_runner_->run_const_fold(/*use_inactive=*/true);
|
||||
}
|
||||
|
||||
void AOTIDelegateExecutor::commitWeights() {
|
||||
LOG(INFO) << "AOTIDelegateExecutor committing weights";
|
||||
aoti_model_container_runner_->swap_constant_buffer();
|
||||
}
|
||||
|
||||
std::vector<at::Tensor> AOTIDelegateExecutor::run(
|
||||
std::vector<at::Tensor>& inputs) {
|
||||
RECORD_USER_SCOPE("sigmoid::AOTIDelegateExecutor::run");
|
||||
std::vector<at::Tensor> outputs = aoti_model_container_runner_->run(inputs);
|
||||
return outputs;
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
|
49
torch/nativert/executor/AOTInductorDelegateExecutor.h
Normal file
49
torch/nativert/executor/AOTInductorDelegateExecutor.h
Normal file
@ -0,0 +1,49 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/inductor/aoti_runner/model_container_runner.h>
|
||||
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
|
||||
#include <torch/nativert/executor/ETDelegateExecutor.h>
|
||||
#include <torch/nativert/executor/ExecutorConfig.h>
|
||||
|
||||
#ifdef USE_CUDA
|
||||
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
|
||||
#endif
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
class AOTIDelegateExecutor : public ETDelegateExecutor {
|
||||
public:
|
||||
explicit AOTIDelegateExecutor(
|
||||
const Node& node,
|
||||
const std::shared_ptr<Weights>& weights,
|
||||
const ExecutorConfig& executorConfig,
|
||||
caffe2::serialize::PyTorchStreamReader* packageReader,
|
||||
const MakeProxyExecutorFn& makeProxyExecutorFunc);
|
||||
~AOTIDelegateExecutor() override = default;
|
||||
|
||||
void processWeights(std::shared_ptr<Weights> weights) override;
|
||||
void initWeights(std::shared_ptr<Weights> weights) override;
|
||||
void commitWeights() override;
|
||||
|
||||
std::vector<at::Tensor> run(std::vector<at::Tensor>& inputs) override;
|
||||
|
||||
private:
|
||||
std::unique_ptr<torch::inductor::AOTIModelContainerRunner>
|
||||
aoti_model_container_runner_;
|
||||
|
||||
// key is weight's original fqn, value is weight's name in AOTI
|
||||
std::unordered_map<std::string, std::string> weight_names_map_;
|
||||
};
|
||||
|
||||
C10_DECLARE_TYPED_REGISTRY(
|
||||
AOTIModelContainerRunnerRegistry,
|
||||
c10::DeviceType,
|
||||
torch::inductor::AOTIModelContainerRunner,
|
||||
std::unique_ptr,
|
||||
const std::string&,
|
||||
size_t,
|
||||
const std::string&,
|
||||
const std::string&,
|
||||
const bool);
|
||||
|
||||
} // namespace torch::nativert
|
@ -0,0 +1,24 @@
|
||||
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
|
||||
#include <torch/nativert/executor/AOTInductorDelegateExecutor.h>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
namespace {
|
||||
std::unique_ptr<torch::inductor::AOTIModelContainerRunner>
|
||||
create_aoti_model_container_runner_cuda(
|
||||
const std::string& model_so_path,
|
||||
size_t num_models,
|
||||
const std::string& device_str,
|
||||
const std::string& cubin_dir,
|
||||
const bool run_single_threaded) {
|
||||
return std::make_unique<torch::inductor::AOTIModelContainerRunnerCuda>(
|
||||
model_so_path, num_models, device_str, cubin_dir, run_single_threaded);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
C10_REGISTER_TYPED_CREATOR(
|
||||
AOTIModelContainerRunnerRegistry,
|
||||
at::kCUDA,
|
||||
create_aoti_model_container_runner_cuda)
|
||||
|
||||
} // namespace torch::nativert
|
34
torch/nativert/executor/ETDelegateExecutor.h
Normal file
34
torch/nativert/executor/ETDelegateExecutor.h
Normal file
@ -0,0 +1,34 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/nativert/executor/DelegateExecutor.h>
|
||||
#include <torch/nativert/executor/ExecutorConfig.h>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
class ETDelegateExecutor : public DelegateExecutor {
|
||||
public:
|
||||
explicit ETDelegateExecutor(
|
||||
const std::string_view& dir_prefix,
|
||||
const Node& node)
|
||||
: delegate_dir_([&]() {
|
||||
const std::string* path =
|
||||
std::get_if<std::string>(&node.attributes()[0].value);
|
||||
TORCH_CHECK(
|
||||
path != nullptr,
|
||||
"et hop's first attribute should correspond to it's path");
|
||||
return std::string(dir_prefix) + *path;
|
||||
}()) {
|
||||
VLOG(1) << "ETDelegateExecutor: " << delegate_dir_;
|
||||
}
|
||||
|
||||
~ETDelegateExecutor() override = default;
|
||||
|
||||
const std::string& get_delegate_dir() {
|
||||
return delegate_dir_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::string delegate_dir_;
|
||||
};
|
||||
|
||||
} // namespace torch::nativert
|
43
torch/nativert/kernels/ETCallDelegateKernel.cpp
Normal file
43
torch/nativert/kernels/ETCallDelegateKernel.cpp
Normal file
@ -0,0 +1,43 @@
|
||||
#include <torch/nativert/kernels/ETCallDelegateKernel.h>
|
||||
|
||||
#include <torch/nativert/executor/ETDelegateExecutor.h>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
ETCallDelegateKernel::ETCallDelegateKernel(
|
||||
const Node* node,
|
||||
ETDelegateExecutor& delegateExecutor)
|
||||
: OpKernel(node), delegateExecutor_(delegateExecutor) {
|
||||
for (const auto& input : node_->inputs()) {
|
||||
TORCH_CHECK(input.value->type() == Type::Kind::Tensor);
|
||||
}
|
||||
|
||||
for (const auto* output : node_->outputs()) {
|
||||
TORCH_CHECK(output->type() == Type::Kind::Tensor);
|
||||
}
|
||||
}
|
||||
|
||||
void ETCallDelegateKernel::computeInternal(
|
||||
ExecutionFrame& executionFrame) const {
|
||||
std::vector<at::Tensor> inputs;
|
||||
inputs.reserve(numInputs());
|
||||
|
||||
for (const auto& input : node_->inputs()) {
|
||||
inputs.emplace_back(executionFrame.getTensor(input.value->id()));
|
||||
}
|
||||
|
||||
auto outputs = delegateExecutor_.run(inputs);
|
||||
const auto& node_outputs = node_->outputs();
|
||||
TORCH_CHECK(outputs.size() == node_outputs.size());
|
||||
|
||||
size_t i = 0;
|
||||
for (auto begin = std::make_move_iterator(outputs.begin()),
|
||||
end = std::make_move_iterator(outputs.end());
|
||||
begin != end;
|
||||
++begin) {
|
||||
executionFrame.setIValue(node_outputs[i]->id(), *begin);
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
|
22
torch/nativert/kernels/ETCallDelegateKernel.h
Normal file
22
torch/nativert/kernels/ETCallDelegateKernel.h
Normal file
@ -0,0 +1,22 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/nativert/executor/ExecutionFrame.h>
|
||||
#include <torch/nativert/executor/OpKernel.h>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
class ETDelegateExecutor;
|
||||
|
||||
class ETCallDelegateKernel : public OpKernel {
|
||||
public:
|
||||
explicit ETCallDelegateKernel(
|
||||
const Node* node,
|
||||
ETDelegateExecutor& delegateExecutor);
|
||||
|
||||
void computeInternal(ExecutionFrame& executionFrame) const override final;
|
||||
|
||||
private:
|
||||
ETDelegateExecutor& delegateExecutor_;
|
||||
};
|
||||
|
||||
} // namespace torch::nativert
|
@ -12,6 +12,10 @@
|
||||
#include <torch/nativert/kernels/KernelFactory.h>
|
||||
#include <torch/nativert/kernels/KernelRegistry.h>
|
||||
|
||||
#include <torch/csrc/inductor/aoti_torch/oss_proxy_executor.h>
|
||||
#include <torch/nativert/executor/AOTInductorDelegateExecutor.h>
|
||||
#include <torch/nativert/kernels/ETCallDelegateKernel.h>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
namespace {
|
||||
@ -31,6 +35,14 @@ std::string maybeRevisedStaticDispatchTarget(const Node& node) {
|
||||
}
|
||||
return std::string(node.target());
|
||||
}
|
||||
|
||||
std::unique_ptr<torch::aot_inductor::ProxyExecutor> make_proxy_executor(
|
||||
const std::string& filename,
|
||||
bool is_cpu,
|
||||
std::optional<std::unordered_map<std::string, c10::IValue>> custom_objs) {
|
||||
return std::make_unique<torch::aot_inductor::OSSProxyExecutor>(
|
||||
filename, is_cpu, std::move(custom_objs));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void register_kernel_handlers() {
|
||||
@ -62,6 +74,35 @@ void register_kernel_handlers() {
|
||||
->Create(maybeRevisedStaticDispatchTarget(node), &node),
|
||||
nullptr};
|
||||
}));
|
||||
KernelFactory::registerHandler(
|
||||
"et_delegate",
|
||||
KernelFactoryHandler(
|
||||
[](const Node& node,
|
||||
const torch::nativert::ExecutorConfig& /* executorConfig */) {
|
||||
return c10::starts_with(
|
||||
node.target(),
|
||||
"torch.ops.higher_order.executorch_call_delegate");
|
||||
},
|
||||
[](const Node& node,
|
||||
// NOLINTNEXTLINE(performance-unnecessary-value-param)
|
||||
std::shared_ptr<Weights> weights,
|
||||
const torch::nativert::ExecutorConfig& executorConfig,
|
||||
caffe2::serialize::PyTorchStreamReader* packageReader)
|
||||
-> std::pair<
|
||||
KernelFactoryHandler::OpKernelPtr,
|
||||
KernelFactoryHandler::DelegateExecutorPtr> {
|
||||
auto delegateExecutor = std::make_unique<AOTIDelegateExecutor>(
|
||||
node,
|
||||
weights,
|
||||
executorConfig,
|
||||
packageReader,
|
||||
make_proxy_executor);
|
||||
|
||||
return {
|
||||
std::make_unique<ETCallDelegateKernel>(
|
||||
&node, *delegateExecutor),
|
||||
std::move(delegateExecutor)};
|
||||
}));
|
||||
});
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user