[aoti] Add initial custom op support (#127034)

Re-land of https://github.com/pytorch/pytorch/pull/125242
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127034
Approved by: https://github.com/malfet

Committed by: PyTorch MergeBot
Parent: 44fdf24967
Commit: b90aa18569
@ -471,6 +471,7 @@ inductor_core_resources = [
    "torch/csrc/inductor/aoti_torch/shim_common.cpp",
    "torch/csrc/inductor/aoti_torch/tensor_converter.cpp",
    "torch/csrc/inductor/aoti_torch/mkldnn_tensor.cpp",
    "torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp",
    "torch/csrc/inductor/inductor_ops.cpp",
]
@ -1213,6 +1213,7 @@ if(BUILD_TEST)
  )
else()
  add_subdirectory(${TORCH_ROOT}/test/cpp/jit ${CMAKE_BINARY_DIR}/test_jit)
  add_subdirectory(${TORCH_ROOT}/test/inductor ${CMAKE_BINARY_DIR}/test_inductor)
  add_subdirectory(
    ${TORCH_ROOT}/test/cpp/tensorexpr
    ${CMAKE_BINARY_DIR}/test_tensorexpr
test/inductor/CMakeLists.txt (new file, 8 lines)
@ -0,0 +1,8 @@
# Build separate libraries the define custom classes/operators used from our Python tests.
# These are intended to be used with torch.ops.load_library() in our Python test suite.
add_library(aoti_custom_ops SHARED custom_ops.cpp)
target_link_libraries(aoti_custom_ops torch)

if(INSTALL_TEST)
  install(TARGETS aoti_custom_ops DESTINATION lib)
endif()
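For reference, a minimal sketch (not part of the diff) of how the Python test suite is expected to consume this library; the path below is hypothetical, and the tests added later in this commit locate the real file via find_library_location("libaoti_custom_ops.so"):

import torch

# Hypothetical build-output path; adjust to wherever the .so/.dll actually lands.
torch.ops.load_library("build/lib/libaoti_custom_ops.so")

x = torch.randn(3, 3)
y = torch.randn(3, 3)
out = torch.ops.aoti_custom_ops.custom_add(x, y)
assert torch.equal(out, x + y)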
test/inductor/custom_ops.cpp (new file, 364 lines)
@ -0,0 +1,364 @@
#include <torch/csrc/api/include/torch/types.h>

#include <cstdint>
#include <iostream>
#include <string>

namespace at {

Tensor custom_add_impl(Tensor t1, Tensor t2) {
  return t1 + t2;
}

Tensor fn_with_all_inputs_impl(
    const Tensor& tensor,
    const c10::List<Tensor>& tensors,
    const c10::List<std::optional<Tensor>>& optional_tensors,
    const bool b8,
    const c10::List<bool>& b8s,
    const int64_t i64,
    const c10::List<int64_t>& i64s,
    const int64_t& symint,
    const IntArrayRef symints,
    const double f64,
    const c10::List<double>& f64s,
    const at::Scalar& scalar,
    at::ArrayRef<at::Scalar> scalars,
    const std::string& string,
    const std::vector<std::string>& strings,
    const c10::ScalarType& dtype,
    const MemoryFormat& memory_format,
    const Layout& layout,
    const Device& device,
    // optional
    const std::optional<Tensor>& o_tensor,
    const std::optional<c10::List<Tensor>>& o_tensors,
    const std::optional<bool>& o_b8,
    const std::optional<c10::List<bool>>& o_b8s,
    const std::optional<int64_t>& o_i64,
    const std::optional<c10::List<int64_t>>& o_i64s,
    const std::optional<int64_t>& o_symint,
    const std::optional<IntArrayRef>& o_symints,
    const std::optional<double>& o_f64,
    const std::optional<c10::List<double>>& o_f64s,
    const std::optional<at::Scalar>& o_scalar,
    const std::optional<at::ArrayRef<at::Scalar>>& o_scalars,
    const std::optional<std::string>& o_string,
    const std::optional<std::vector<std::string>>& o_strings,
    const std::optional<c10::ScalarType>& o_dtype,
    const std::optional<MemoryFormat>& o_memory_format,
    const std::optional<Layout>& o_layout,
    const std::optional<Device>& o_device) {
  std::cout << "tensor shape: " << tensor.sizes() << std::endl;

  std::cout << "tensors shape: ";
  for (auto t : tensors) {
    std::cout << t.get().toTensor().sizes() << ", ";
  }
  std::cout << std::endl;

  std::cout << "optional tensors shape: ";
  for (auto t : optional_tensors) {
    if (t.get().toOptional<Tensor>().has_value()) {
      std::cout << t.get().toTensor().sizes() << ", ";
    } else {
      std::cout << "None, ";
    }
  }
  std::cout << std::endl;

  std::cout << "b8 " << c10::IValue(b8) << std::endl;
  std::cout << "b8s " << c10::IValue(b8s) << std::endl;
  std::cout << "i64 " << c10::IValue(i64) << std::endl;
  std::cout << "i64s " << c10::IValue(i64s) << std::endl;
  std::cout << "symint " << c10::IValue(symint) << std::endl;
  std::cout << "symints " << c10::IValue(symints) << std::endl;
  std::cout << "f64 " << c10::IValue(f64) << std::endl;
  std::cout << "f64s " << c10::IValue(f64s) << std::endl;
  std::cout << "scalar " << c10::IValue(scalar) << std::endl;
  std::cout << "scalars " << c10::IValue(scalars) << std::endl;
  std::cout << "string " << c10::IValue(string) << std::endl;
  std::cout << "strings " << c10::IValue(strings) << std::endl;
  std::cout << "dtype " << c10::IValue(dtype) << std::endl;
  std::cout << "memory_format " << c10::IValue(memory_format) << std::endl;
  std::cout << "layout " << c10::IValue(layout) << std::endl;
  std::cout << "device " << c10::IValue(device) << std::endl;

  std::cout << "o_tensor "
            << (o_tensor.has_value() ? c10::IValue(o_tensor.value().sizes())
                                     : "None")
            << std::endl;

  std::cout << "o_tensors shape: ";
  if (o_tensors.has_value()) {
    for (auto t : o_tensors.value()) {
      std::cout << t.get().toTensor().sizes() << ", ";
    }
  } else {
    std::cout << "None";
  }
  std::cout << std::endl;

  std::cout << "o_b8 "
            << (o_b8.has_value() ? c10::IValue(o_b8.value()) : "None")
            << std::endl;
  std::cout << "o_b8s "
            << (o_b8s.has_value() ? c10::IValue(o_b8s.value()) : "None")
            << std::endl;
  std::cout << "o_i64 "
            << (o_i64.has_value() ? c10::IValue(o_i64.value()) : "None")
            << std::endl;
  std::cout << "o_i64s "
            << (o_i64s.has_value() ? c10::IValue(o_i64s.value()) : "None")
            << std::endl;
  std::cout << "o_symint "
            << (o_symint.has_value() ? c10::IValue(o_symint.value()) : "None")
            << std::endl;
  std::cout << "o_symints "
            << (o_symints.has_value() ? c10::IValue(o_symints.value()) : "None")
            << std::endl;
  std::cout << "o_f64 "
            << (o_f64.has_value() ? c10::IValue(o_f64.value()) : "None")
            << std::endl;
  std::cout << "o_f64s "
            << (o_f64s.has_value() ? c10::IValue(o_f64s.value()) : "None")
            << std::endl;
  std::cout << "o_scalar "
            << (o_scalar.has_value() ? c10::IValue(o_scalar.value()) : "None")
            << std::endl;
  std::cout << "o_scalars "
            << (o_scalars.has_value() ? c10::IValue(o_scalars.value()) : "None")
            << std::endl;
  std::cout << "o_string "
            << (o_string.has_value() ? c10::IValue(o_string.value()) : "None")
            << std::endl;
  std::cout << "o_strings "
            << (o_strings.has_value() ? c10::IValue(o_strings.value()) : "None")
            << std::endl;
  std::cout << "o_dtype "
            << (o_dtype.has_value() ? c10::IValue(o_dtype.value()) : "None")
            << std::endl;
  std::cout << "o_memory_format "
            << (o_memory_format.has_value()
                    ? c10::IValue(o_memory_format.value())
                    : "None")
            << std::endl;
  std::cout << "o_layout "
            << (o_layout.has_value() ? c10::IValue(o_layout.value()) : "None")
            << std::endl;
  std::cout << "o_device "
            << (o_device.has_value() ? c10::IValue(o_device.value()) : "None")
            << std::endl;

  int64_t int_hash = 0;
  int_hash ^= i64;
  for (auto i : i64s) {
    int_hash ^= i;
  }
  if (o_i64.has_value()) {
    int_hash ^= o_i64.value();
  }
  if (o_i64s.has_value()) {
    for (auto i : o_i64s.value()) {
      int_hash ^= i;
    }
  }

  int_hash ^= symint;
  for (auto i : symints) {
    int_hash ^= i;
  }
  if (o_symint.has_value()) {
    int_hash ^= o_symint.value();
  }
  if (o_symints.has_value()) {
    for (auto i : o_symints.value()) {
      int_hash ^= i;
    }
  }

  return tensor + int_hash;
}

Tensor fn_with_default_input_impl(const Tensor& tensor, const int64_t i64) {
  return tensor + i64;
}

std::tuple<Tensor, Tensor> fn_with_tuple_output_impl(
    const Tensor& tensor,
    const int64_t i64) {
  return {tensor + i64, tensor - i64};
}

std::vector<Tensor> fn_with_list_output_impl(
    TensorList tensors,
    const int64_t i64) {
  std::vector<Tensor> outputs;
  for (auto& t : tensors) {
    outputs.emplace_back(t + i64);
  }
  return outputs;
}

std::tuple<Tensor, std::vector<Tensor>> fn_with_mix_outputs_impl(
    const Tensor& tensor,
    TensorList tensors) {
  std::vector<Tensor> outputs;
  for (auto& t : tensors) {
    outputs.emplace_back(t + 2);
  }
  return {tensor + 1, outputs};
}

std::tuple<Tensor, Tensor> fn_with_input_mutation_impl(
    Tensor& t0,
    const Tensor& t1,
    Tensor& t2) {
  t0.add_(1);
  t2.sub_(1);
  return {t1 + 1, t1 + 2};
}

// NOLINTBEGIN(clang-diagnostic-unused-parameter)
Tensor fn_with_all_inputs_meta(
    const Tensor& tensor,
    const c10::List<Tensor>& tensors,
    const c10::List<std::optional<Tensor>>& optional_tensors,
    const bool b8,
    const c10::List<bool>& b8s,
    const int64_t i64,
    const c10::List<int64_t>& i64s,
    const c10::SymInt& symint,
    c10::SymIntArrayRef symints,
    const double f64,
    const c10::List<double>& f64s,
    const at::Scalar& scalar,
    at::ArrayRef<at::Scalar> scalars,
    const std::string& string,
    const std::vector<std::string>& strings,
    const c10::ScalarType& dtype,
    const MemoryFormat& memory_format,
    const Layout& layout,
    const Device& device,
    // optional
    const std::optional<Tensor>& o_tensor,
    const std::optional<c10::List<Tensor>>& o_tensors,
    const std::optional<bool>& o_b8,
    const std::optional<c10::List<bool>>& o_b8s,
    const std::optional<int64_t>& o_i64,
    const std::optional<c10::List<int64_t>>& o_i64s,
    const std::optional<c10::SymInt>& o_symint,
    at::OptionalSymIntArrayRef o_symints,
    const std::optional<double>& o_f64,
    const std::optional<c10::List<double>>& o_f64s,
    const std::optional<at::Scalar>& o_scalar,
    const std::optional<at::ArrayRef<at::Scalar>>& o_scalars,
    const std::optional<std::string>& o_string,
    const std::optional<std::vector<std::string>>& o_strings,
    const std::optional<c10::ScalarType>& o_dtype,
    const std::optional<MemoryFormat>& o_memory_format,
    const std::optional<Layout>& o_layout,
    const std::optional<Device>& o_device) {
  return tensor;
}

Tensor fn_with_default_input_meta(const Tensor& tensor, const int64_t i64) {
  return tensor.clone();
}

std::tuple<Tensor, Tensor> fn_with_tuple_output_meta(
    const Tensor& tensor,
    const int64_t i64) {
  return {tensor.clone(), tensor.clone()};
}

std::vector<Tensor> fn_with_list_output_meta(
    TensorList tensors,
    const int64_t i64) {
  std::vector<Tensor> outputs;
  for (auto& t : tensors) {
    outputs.push_back(t.clone());
  }
  return outputs;
}

std::tuple<Tensor, std::vector<Tensor>> fn_with_mix_outputs_meta(
    const Tensor& tensor,
    TensorList tensors) {
  std::vector<Tensor> outputs;
  for (auto& t : tensors) {
    outputs.push_back(t.clone());
  }
  return {tensor.clone(), outputs};
}

std::tuple<Tensor, Tensor> fn_with_input_mutation_meta(
    Tensor& t0,
    const Tensor& t1,
    Tensor& t2) {
  return {t1.clone(), t1.clone()};
}

} // namespace at

TORCH_LIBRARY(aoti_custom_ops, m) {
  m.def("custom_add(Tensor t1, Tensor t2) -> Tensor");
  m.def(
      "fn_with_all_inputs(Tensor tensor, "
      "Tensor[] tensors, "
      "Tensor?[] optional_tensors, "
      "bool b8, bool[] b8s, "
      "int i64, int[] i64s, "
      "SymInt symint, SymInt[] symints, "
      "float f64, float[] f64s, "
      "Scalar scalar, Scalar[] scalars, "
      "str string, str[] strings, "
      "ScalarType dtype, "
      "MemoryFormat memory_format, "
      "Layout layout, "
      "Device device, "
      "*, "
      "Tensor? o_tensor, Tensor[]? o_tensors, "
      "bool? o_b8, bool[]? o_b8s, "
      "int? o_i64, int[]? o_i64s, "
      "SymInt? o_symint, SymInt[]? o_symints, "
      "float? o_f64, float[]? o_f64s, "
      "Scalar? o_scalar, Scalar[]? o_scalars, "
      "str? o_string, str[]? o_strings, "
      "ScalarType? o_dtype, "
      "MemoryFormat? o_memory_format, "
      "Layout? o_layout, "
      "Device? o_device) -> Tensor");

  m.def("fn_with_default_input(Tensor t, int i=3) -> Tensor");

  m.def("fn_with_tuple_output(Tensor t, int i) -> (Tensor, Tensor)");

  m.def("fn_with_list_output(Tensor[] tensors, int i) -> Tensor[]");

  m.def(
      "fn_with_mix_outputs(Tensor t, Tensor[] tensors) -> (Tensor, Tensor[])");

  m.def(
      "fn_with_input_mutation(Tensor(a!) t0, Tensor t1, Tensor(b!) t2) -> (Tensor, Tensor)");

}

TORCH_LIBRARY_IMPL(aoti_custom_ops, CompositeExplicitAutograd, m) {
  m.impl("custom_add", at::custom_add_impl);
  m.impl("fn_with_all_inputs", at::fn_with_all_inputs_impl);
  m.impl("fn_with_default_input", at::fn_with_default_input_impl);
  m.impl("fn_with_tuple_output", at::fn_with_tuple_output_impl);
  m.impl("fn_with_list_output", at::fn_with_list_output_impl);
  m.impl("fn_with_mix_outputs", at::fn_with_mix_outputs_impl);
  m.impl("fn_with_input_mutation", at::fn_with_input_mutation_impl);
}

TORCH_LIBRARY_IMPL(aoti_custom_ops, Meta, m) {
  m.impl("fn_with_all_inputs", at::fn_with_all_inputs_meta);
  m.impl("fn_with_default_input", at::fn_with_default_input_meta);
  m.impl("fn_with_tuple_output", at::fn_with_tuple_output_meta);
  m.impl("fn_with_list_output", at::fn_with_list_output_meta);
  m.impl("fn_with_mix_outputs", at::fn_with_mix_outputs_meta);
  m.impl("fn_with_input_mutation", at::fn_with_input_mutation_meta);
}
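The Meta-key registrations above are what let these ops run under meta/fake tensors while AOTInductor traces and compiles a model, so output shapes can be inferred without executing the real kernels. A rough sketch of that behavior, assuming the library has been loaded as shown earlier (the path is hypothetical):

import torch

torch.ops.load_library("build/lib/libaoti_custom_ops.so")  # hypothetical path
t = torch.empty(3, 3, device="meta")
out = torch.ops.aoti_custom_ops.fn_with_default_input(t, 3)
print(out.shape)  # torch.Size([3, 3]); only fn_with_default_input_meta ran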
@ -30,9 +30,11 @@ from torch.testing._internal.common_quantization import (
)
from torch.testing._internal.common_utils import (
    DeterministicGuard,
    find_library_location,
    IS_CI,
    IS_FBCODE,
    IS_MACOS,
    IS_SANDCASTLE,
    IS_WINDOWS,
    skipIfRocm,
    TEST_WITH_ROCM,
@ -2469,6 +2471,18 @@ class AOTInductorTestsTemplate:
        model.weight += 1
        self.check_model(model, example_inputs)

    def test_custom_op_add(self) -> None:
        class M(torch.nn.Module):
            def forward(self, x, y):
                return torch.ops.aoti_custom_ops.custom_add(x, y)

        m = M().to(device=self.device)
        args = (
            torch.randn(3, 3, device=self.device),
            torch.randn(3, 3, device=self.device),
        )
        self.check_model(m, args)

    def test_triton_kernel_extern_kernel_arg(self):
        if self.device != "cuda":
            raise unittest.SkipTest("requires CUDA")
@ -3040,7 +3054,21 @@ class AOTInductorTestsTemplate:
common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate)


class AOTInductorTestABICompatibleCpu(TestCase):
class AOTITestCase(TestCase):
    def setUp(self):
        if IS_SANDCASTLE or IS_FBCODE:
            torch.ops.load_library("//caffe2/test/inductor:custom_ops")
        elif IS_MACOS:
            raise unittest.SkipTest("non-portable load_library call used in test")
        else:
            lib_file_path = find_library_location("libaoti_custom_ops.so")
            if IS_WINDOWS:
                lib_file_path = find_library_location("aoti_custom_ops.dll")
            torch.ops.load_library(str(lib_file_path))
        super().setUp()


class AOTInductorTestABICompatibleCpu(AOTITestCase):
    device = "cpu"
    abi_compatible = True
    check_model = check_model
@ -3199,6 +3227,7 @@ CPU_TEST_FAILURES = {
    "test_while_loop_with_outer_buffers": fail_stack_allocation(is_skip=True),
    # TODO: use of undeclared identifier 'float8_e4m3fn' and 'half'
    "test_fp8": fail_minimal_arrayref_interface(is_skip=True),
    "test_custom_op_add": fail_minimal_arrayref_interface(is_skip=True),
}

# test_failures, xfail by default, set is_skip=True to skip
@ -3213,6 +3242,7 @@ CUDA_TEST_FAILURES = {
    "test_runtime_checks_shape_failed": fail_non_abi_compatible_cuda(is_skip=True),
    # quantized unsupported for GPU
    "test_quantized_linear": fail_cuda(is_skip=True),
    "test_custom_op_add": fail_non_abi_compatible_cuda(is_skip=True),
}

@ -3268,7 +3298,7 @@ copy_tests(
)


class AOTInductorTestABICompatibleCpuWithStackAllocation(TestCase):
class AOTInductorTestABICompatibleCpuWithStackAllocation(AOTITestCase):
    device = "cpu"
    abi_compatible = True
    check_model = check_model
@ -3306,7 +3336,7 @@ copy_tests(


@unittest.skipIf(sys.platform == "darwin", "No CUDA on MacOS")
class AOTInductorTestABICompatibleCuda(TestCase):
class AOTInductorTestABICompatibleCuda(AOTITestCase):
    device = "cuda"
    abi_compatible = True
    check_model = check_model
@ -3328,7 +3358,7 @@ copy_tests(
    IS_FBCODE or sys.platform == "darwin",
    "NonABI mode should not be used in fbcode nor on MacOS",
)
class AOTInductorTestNonABICompatibleCpu(TestCase):
class AOTInductorTestNonABICompatibleCpu(AOTITestCase):
    device = "cpu"
    abi_compatible = False
    check_model = check_model
@ -3355,6 +3385,7 @@ copy_tests(
        "test_runtime_checks_shape_failed": TestFailure(
            ("non_abi_compatible_cpu",), is_skip=True
        ),
        "test_custom_op_add": TestFailure(("non_abi_compatible_cpu",), is_skip=True),
    },
)

@ -3363,7 +3394,7 @@ copy_tests(
    IS_FBCODE or sys.platform == "darwin",
    "NonABI mode should not be used in fbcode nor on MacOS",
)
class AOTInductorTestNonABICompatibleCuda(TestCase):
class AOTInductorTestNonABICompatibleCuda(AOTITestCase):
    device = "cuda"
    abi_compatible = False
    check_model = check_model
@ -47,6 +47,15 @@ class AOTIRunnerUtil:
                restore_fqn=False,
            )

        if IS_FBCODE:
            from deeplearning.aot_inductor.extern_node_thrift_serializer import (
                thrift_serializer,
            )

            if options is None:
                options = {}
            options["extern_node_serializer"] = thrift_serializer

        with torch.no_grad():
            so_path = torch._inductor.aot_compile(gm, example_inputs, options=options)  # type: ignore[arg-type]

third_party/nlohmann.BUILD (vendored)
@ -1,20 +1,18 @@
load("@rules_cc//cc:defs.bzl", "cc_library")

cc_library(
    name = "nlohmann",
    hdrs = glob(["include/**/*.hpp"]),
    includes = [
        "/",
    ],
cc_library(name = "nlohmann",
    includes = ["include"],
    deps = ["nlohmann-internal"],
    visibility = ["//visibility:public"],
)

cc_import(name = "nlohmann-internal",
    hdrs = glob(["include/**/*.hpp"]),
    visibility = ["//visibility:private"],
)

cc_library(
    name = "nlohmann_single_include",
    hdrs = glob(["single_include/nlohmann/*.hpp"]),
    includes = [
        "/",
    ],
    visibility = ["//visibility:public"],
)
torch/_export/serde/aoti_schema.py (new file, 15 lines)
@ -0,0 +1,15 @@
from dataclasses import dataclass
from typing import List

from torch._export.serde.schema import Node


@dataclass
class ExternKernelNode:
    name: str
    node: Node


@dataclass
class ExternKernelNodes:
    nodes: List[ExternKernelNode]
@ -1797,7 +1797,7 @@ class AotCodeCompiler:
        with lock:
            # Currently, this only support serializing extern nodes in fbcode
            # Eventually, we should also have a serializer for OSS.
            if config.is_fbcode() and serialized_extern_kernel_nodes:
            if serialized_extern_kernel_nodes:
                output_json = os.path.splitext(input_path)[0] + ".json"
                with open(output_json, "w") as f:
                    f.write(serialized_extern_kernel_nodes)

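With the is_fbcode() guard dropped, extern kernel nodes are now serialized in OSS builds as well, and the JSON is written next to the generated artifacts. A hedged sketch of the expected on-disk result when AOT-compiling a module that calls a custom op (exported_gm and example_inputs are placeholders, not names from this diff):

import os
import torch

so_path = torch._inductor.aot_compile(exported_gm, example_inputs)  # placeholders
json_path = os.path.splitext(so_path)[0] + ".json"
# The runtime (AOTIModelContainerRunner, below) probes for this file and, if present,
# constructs an OSSProxyExecutor from it to dispatch the custom ops.
print(os.path.exists(json_path))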
@ -2111,19 +2111,19 @@ class CppWrapperCpu(WrapperCodeGen):
        if isinstance(output_args, str):
            output_args = [output_args]

        if config.is_fbcode():
        if V.graph.aot_mode and config.abi_compatible:
            assert op_overload is not None
            assert raw_args is not None
            assert outputs is not None

            return self.generate_extern_kernel_alloc_and_find_schema_if_needed_fbcode(
            return self.generate_extern_kernel_alloc_and_find_schema_if_needed_with_proxy_executor(
                cpp_kernel_key,
                op_overload,
                raw_args,
                output_args,
            )
        else:
            return self.generate_extern_kernel_alloc_and_find_schema_if_needed_oss(
            return self.generate_extern_kernel_alloc_and_find_schema_if_needed_jit(
                buf_name,
                python_kernel_name,
                cpp_kernel_name,
@ -2207,7 +2207,7 @@ if (custom_op_wrapper.get() == NULL) {
            lines += f"PyTuple_SetItem({py_args_var}, {idx}, {generate_py_arg_inner(raw_arg, arg_type)});\n"
        return lines

    def generate_extern_kernel_alloc_and_find_schema_if_needed_oss(
    def generate_extern_kernel_alloc_and_find_schema_if_needed_jit(
        self,
        buf_name: str,
        python_kernel_name: str,
@ -2220,7 +2220,7 @@ if (py_{buf_name}.get() == NULL) {
        raw_args=None,
        output_args: Optional[List[str]] = None,
    ):
        if V.graph.aot_mode or not config.abi_compatible:
        if not config.abi_compatible:
            # Will update this to use an OSS version ProxyExecutor
            if cpp_kernel_key not in self.extern_call_ops:
                self.writeline(
@ -2288,7 +2288,7 @@ if (py_{buf_name}.get() == NULL) {{
                )
            self.writelines(scope_gil_acquire)

    def generate_extern_kernel_alloc_and_find_schema_if_needed_fbcode(
    def generate_extern_kernel_alloc_and_find_schema_if_needed_with_proxy_executor(
        self,
        cpp_kernel_key,
        op_overload,
torch/_inductor/extern_node_serializer.py (new file, 26 lines)
@ -0,0 +1,26 @@
import json
from typing import List

from torch._export.serde.aoti_schema import ExternKernelNode, ExternKernelNodes, Node
from torch._export.serde.serialize import _dataclass_to_dict, EnumEncoder

from torch._inductor.ir import ExternKernelNode as inductor_ExternKernelNode


def serialize_extern_kernel_node(
    extern_kernel_node: inductor_ExternKernelNode,
) -> ExternKernelNode:
    assert isinstance(extern_kernel_node.node, Node)
    return ExternKernelNode(
        name=extern_kernel_node.name,
        node=extern_kernel_node.node,
    )


def extern_node_json_serializer(
    extern_kernel_nodes: List[inductor_ExternKernelNode],
) -> str:
    serialized_nodes = ExternKernelNodes(
        nodes=[serialize_extern_kernel_node(node) for node in extern_kernel_nodes]
    )
    return json.dumps(_dataclass_to_dict(serialized_nodes), cls=EnumEncoder)
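A small usage sketch of the OSS serializer added above; with no extern kernel nodes the payload is simply an empty "nodes" list, and each entry otherwise carries the node name plus its export-schema Node:

from torch._inductor.extern_node_serializer import extern_node_json_serializer

# Produces a JSON string with a top-level "nodes" list (empty here).
print(extern_node_json_serializer([]))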
@ -380,9 +380,15 @@ class GraphLowering(torch.fx.Interpreter):
        self.wrapper_code: WrapperCodeGen = None  # type: ignore[assignment]
        # See `ProxyExecutor Design Note` in ir.py for more details
        self.extern_kernel_nodes: List[ir.ExternKernelNode] = []
        self.extern_node_serializer: Optional[
            Callable[[List[ir.ExternKernelNode]], Any]
        ] = extern_node_serializer

        from torch._inductor.extern_node_serializer import extern_node_json_serializer

        self.extern_node_serializer: Callable[[List[ir.ExternKernelNode]], Any] = (
            extern_node_serializer
            if config.is_fbcode() and extern_node_serializer
            else extern_node_json_serializer
        )

        self.current_node: torch.fx.Node = None  # type: ignore[assignment]
        self.lists: Dict[str, List[str]] = {}
        self.mutated_inputs: Set[str] = set()
@ -1808,11 +1814,7 @@ class GraphLowering(torch.fx.Interpreter):
        output_code_log.debug("Output code: \n%s", code)

        serialized_extern_kernel_nodes = None
        if (
            config.is_fbcode()
            and self.extern_kernel_nodes
            and self.extern_node_serializer
        ):
        if self.extern_kernel_nodes:
            serialized_extern_kernel_nodes = self.extern_node_serializer(
                self.extern_kernel_nodes
            )
@ -2,8 +2,18 @@
#include <ATen/DynamicLibrary.h>

#include <torch/csrc/inductor/aoti_runner/model_container_runner.h>
#include <torch/csrc/inductor/aoti_torch/oss_proxy_executor.h>
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>

// TODO: Investigate why this is necessary, but fixes build problems in FRL
#if __has_include("filesystem")
#include <filesystem>
namespace fs = std::filesystem;
#else
#include <experimental/filesystem>
namespace fs = std::experimental::filesystem;
#endif

namespace torch::inductor {

AOTIModelContainerRunner::AOTIModelContainerRunner(
@ -46,6 +56,17 @@ AOTIModelContainerRunner::AOTIModelContainerRunner(
  get_call_spec_func_ = reinterpret_cast<decltype(get_call_spec_func_)>(
      model_so_->sym("AOTInductorModelContainerGetCallSpec"));

  // Hack to find the json file name from the model so file
  size_t lastindex = model_so_path.find_last_of(".");
  std::string json_filename = model_so_path.substr(0, lastindex) + ".json";

  if (fs::exists(json_filename)) {
    proxy_executor_ = std::make_unique<torch::aot_inductor::OSSProxyExecutor>(
        json_filename, device_str == "cpu");
    proxy_executor_handle_ =
        reinterpret_cast<AOTIProxyExecutorHandle>(proxy_executor_.get());
  }

  AOTI_RUNTIME_ERROR_CODE_CHECK(create_func_(
      &container_handle_,
      num_models,
@ -3,6 +3,7 @@

#include <ATen/Tensor.h>
#include <torch/csrc/inductor/aoti_runtime/interface.h>
#include <torch/csrc/inductor/aoti_torch/proxy_executor.h>

// Forward declare DynamicLibrary
namespace at {
@ -75,9 +76,10 @@ class TORCH_API AOTIModelContainerRunner {

  AOTInductorModelContainerHandle container_handle_ = nullptr;

  // TODO: need an OSS proxy executor implementation. For now,
  // proxy_executor_handle_ will always be nullptr.
  AOTIProxyExecutorHandle proxy_executor_handle_ = nullptr;
  AOTIProxyExecutorHandle proxy_executor_handle_;

 private:
  std::unique_ptr<torch::aot_inductor::ProxyExecutor> proxy_executor_;
};

} // namespace torch::inductor
torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp
Normal file
262
torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp
Normal file
@ -0,0 +1,262 @@
|
||||
#include <nlohmann/json.hpp>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
|
||||
#include <torch/csrc/inductor/aoti_torch/oss_proxy_executor.h>
|
||||
|
||||
namespace {
|
||||
at::Tensor* tensor_handle_to_tensor_pointer(AtenTensorHandle handle) {
|
||||
return reinterpret_cast<at::Tensor*>(handle);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace torch::aot_inductor {
|
||||
|
||||
void OSSProxyExecutor::prefill_stack_with_static_arguments(
|
||||
int index,
|
||||
at::TypePtr schema_arg_type,
|
||||
const nlohmann::json& serialized_arg,
|
||||
OpKernel& op_kernel) {
|
||||
auto& stack = op_kernel.stack_;
|
||||
auto& dynamic_args = op_kernel.dynamic_args_;
|
||||
|
||||
TORCH_CHECK(serialized_arg.size() == 1);
|
||||
std::string serialized_arg_type = serialized_arg.begin().key();
|
||||
auto& serialized_arg_val = serialized_arg.begin().value();
|
||||
|
||||
switch (schema_arg_type->kind()) {
|
||||
case c10::TypeKind::TensorType: {
|
||||
TORCH_CHECK(serialized_arg_type == "as_tensor");
|
||||
stack.emplace_back();
|
||||
dynamic_args.emplace_back(
|
||||
index, DynamicArgType::TensorType, 1, std::move(serialized_arg_val));
|
||||
break;
|
||||
}
|
||||
// TODO: handle the other input types
|
||||
default:
|
||||
TORCH_CHECK(false, "Unsupported input type ", serialized_arg_type);
|
||||
}
|
||||
}
|
||||
|
||||
// Populates op_kernel.stack_, op_kernel.dynamic_args_
|
||||
void OSSProxyExecutor::get_input_info_from_serialized(
|
||||
const std::vector<c10::Argument>& schema_args,
|
||||
const nlohmann::json& serialized_node,
|
||||
OpKernel& op_kernel) {
|
||||
int index = 0;
|
||||
for (const auto& named_argument : serialized_node["inputs"]) {
|
||||
const auto& arg = named_argument["arg"];
|
||||
auto& schema_arg = schema_args[index];
|
||||
|
||||
prefill_stack_with_static_arguments(
|
||||
index++, schema_arg.real_type(), arg, op_kernel);
|
||||
}
|
||||
|
||||
// TODO: prefill default values
|
||||
}
|
||||
|
||||
// Populates op_kernel.outputs_
|
||||
void OSSProxyExecutor::get_output_info_from_serialized(
|
||||
const std::vector<c10::Argument>& schema_returns,
|
||||
const nlohmann::json& serialized_node,
|
||||
OpKernel& op_kernel) {
|
||||
std::vector<DynamicArg>& outputs = op_kernel.outputs_;
|
||||
|
||||
TORCH_CHECK(
|
||||
schema_returns.size() == serialized_node["outputs"].size(),
|
||||
"Serialized node doesn't match op's schema outputs.");
|
||||
|
||||
size_t output_index = 0;
|
||||
for (const auto& serialized_output : serialized_node["outputs"]) {
|
||||
TORCH_CHECK(serialized_output.size() == 1);
|
||||
std::string serialized_output_type = serialized_output.begin().key();
|
||||
auto& serialized_output_val = serialized_output.begin().value();
|
||||
|
||||
auto& schema_return = schema_returns[output_index];
|
||||
at::TypePtr schema_return_type = schema_return.real_type();
|
||||
|
||||
switch (schema_return_type->kind()) {
|
||||
case c10::TypeKind::TensorType: {
|
||||
TORCH_CHECK(
|
||||
serialized_output_type == "as_tensor",
|
||||
serialized_node["target"],
|
||||
" got serialized_output_type of ",
|
||||
serialized_output_type);
|
||||
outputs.emplace_back(
|
||||
output_index,
|
||||
DynamicArgType::TensorType,
|
||||
1,
|
||||
serialized_output_type);
|
||||
break;
|
||||
}
|
||||
case c10::TypeKind::ListType: {
|
||||
if (schema_return_type->isSubtypeOf(at::ListType::ofTensors())) {
|
||||
TORCH_CHECK(
|
||||
serialized_output_type == "as_tensors",
|
||||
serialized_node["target"],
|
||||
" got serialized_output_type of ",
|
||||
serialized_output_type);
|
||||
outputs.emplace_back(
|
||||
output_index,
|
||||
DynamicArgType::ListTensorType,
|
||||
serialized_output_val.size(),
|
||||
serialized_output_type);
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Unsupported return list type ",
|
||||
schema_return_type->repr_str());
|
||||
}
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
TORCH_CHECK(
|
||||
false, "Unsupported return type ", schema_return_type->repr_str());
|
||||
}
|
||||
}
|
||||
|
||||
output_index++;
|
||||
}
|
||||
}
|
||||
|
||||
OSSProxyExecutor::OSSProxyExecutor(const std::string& json_path, bool is_cpu) {
|
||||
if (is_cpu) {
|
||||
device_ = std::make_unique<c10::Device>(c10::DeviceType::CPU);
|
||||
} else {
|
||||
int device_idx = -1;
|
||||
device_ = std::make_unique<c10::Device>(c10::DeviceType::CUDA, device_idx);
|
||||
}
|
||||
|
||||
std::string extern_kernel_nodes_serialized;
|
||||
|
||||
std::ifstream json_file(json_path);
|
||||
TORCH_CHECK(json_file.is_open());
|
||||
|
||||
// Parse file into a json object
|
||||
nlohmann::json json_obj;
|
||||
json_file >> json_obj;
|
||||
|
||||
// Access data
|
||||
for (auto const& serialized_extern_node : json_obj["nodes"]) {
|
||||
auto const& serialized_node = serialized_extern_node["node"];
|
||||
|
||||
const std::string& target = serialized_node["target"];
|
||||
|
||||
std::string opName;
|
||||
std::string overloadName;
|
||||
size_t pos = target.find('.');
|
||||
if (pos == std::string::npos) {
|
||||
opName = target;
|
||||
overloadName = "";
|
||||
} else {
|
||||
// There should be no more periods
|
||||
size_t pos2 = target.find('.', pos);
|
||||
TORCH_CHECK(pos2 == std::string::npos);
|
||||
|
||||
opName = target.substr(0, pos);
|
||||
overloadName = target.substr(pos + 1, target.length() - pos);
|
||||
}
|
||||
|
||||
c10::OperatorHandle op_handle =
|
||||
c10::Dispatcher::singleton().findSchemaOrThrow(
|
||||
opName.c_str(), overloadName.c_str());
|
||||
const c10::FunctionSchema& schema = op_handle.schema();
|
||||
|
||||
const auto& schema_args = schema.arguments();
|
||||
const auto& schema_returns = schema.returns();
|
||||
|
||||
OpKernel op_kernel(target, op_handle);
|
||||
get_input_info_from_serialized(schema_args, serialized_node, op_kernel);
|
||||
get_output_info_from_serialized(schema_returns, serialized_node, op_kernel);
|
||||
|
||||
op_kernels_.emplace_back(std::move(op_kernel));
|
||||
}
|
||||
}
|
||||
|
||||
void OSSProxyExecutor::call_function(
|
||||
int extern_node_index,
|
||||
int num_ints,
|
||||
int64_t* flatten_int_args,
|
||||
int num_tensors,
|
||||
AtenTensorHandle* flatten_tensor_args) {
|
||||
TORCH_CHECK(
|
||||
extern_node_index < static_cast<int>(op_kernels_.size()),
|
||||
"Invalid extern node index");
|
||||
OpKernel& op_kernel = op_kernels_[extern_node_index];
|
||||
|
||||
std::vector<c10::IValue> stack = op_kernel.stack_;
|
||||
auto& dynamic_args = op_kernel.dynamic_args_;
|
||||
|
||||
int tensor_id = 0;
|
||||
int int_id = 0;
|
||||
for (auto& dynamic_arg : dynamic_args) {
|
||||
int arg_index = dynamic_arg.arg_index;
|
||||
DynamicArgType dynamic_arg_type = dynamic_arg.arg_type;
|
||||
int length = dynamic_arg.length;
|
||||
|
||||
if (length == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
switch (dynamic_arg_type) {
|
||||
case DynamicArgType::TensorType: {
|
||||
at::Tensor* tensor =
|
||||
tensor_handle_to_tensor_pointer(flatten_tensor_args[tensor_id++]);
|
||||
stack[arg_index] = *tensor;
|
||||
break;
|
||||
}
|
||||
// TODO: handle other dynamic arg types
|
||||
default:
|
||||
TORCH_CHECK(false, "Unsupported dynamic arg type: ", dynamic_arg_type);
|
||||
}
|
||||
}
|
||||
|
||||
int num_output_tensors = op_kernel.num_output_tensors();
|
||||
TORCH_CHECK(
|
||||
tensor_id == num_tensors - num_output_tensors,
|
||||
"Mismatch between tensors consumed and num of input tensor, got tensor_id = .",
|
||||
tensor_id,
|
||||
", expected num = ",
|
||||
num_tensors - num_output_tensors);
|
||||
TORCH_CHECK(
|
||||
int_id == num_ints,
|
||||
"Mismatch between ints consumed and num_ints, got int_id = ",
|
||||
int_id,
|
||||
", num_ints = ",
|
||||
num_ints);
|
||||
|
||||
// Call the op with the prepared stack.
|
||||
const c10::OperatorHandle& op = op_kernel.op_handle_;
|
||||
op.callBoxed(stack);
|
||||
|
||||
const c10::FunctionSchema& schema = op.schema();
|
||||
const auto& schema_returns = schema.returns();
|
||||
|
||||
TORCH_CHECK(op_kernel.outputs_.size() == stack.size());
|
||||
// TODO: what about optional outputs? This assert may not hold
|
||||
TORCH_CHECK(stack.size() == schema_returns.size());
|
||||
|
||||
int index = 0;
|
||||
for (const auto& schema_return : schema_returns) {
|
||||
if (schema_return.type()->kind() == c10::TypeKind::TensorType) {
|
||||
at::Tensor* tensor =
|
||||
tensor_handle_to_tensor_pointer(flatten_tensor_args[tensor_id++]);
|
||||
*tensor = stack[index++].toTensor();
|
||||
// TODO: handle tensor list returns
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"NYI: Unsupported return type for schema: ",
|
||||
schema_return.type()->repr_str());
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_CHECK(
|
||||
tensor_id == num_tensors,
|
||||
"Mismatch between tensors consumed and num_tensors, got tensor_id = ",
|
||||
tensor_id,
|
||||
", expected num = ",
|
||||
num_tensors);
|
||||
}
|
||||
|
||||
} // namespace torch::aot_inductor
|
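For orientation, a hedged sketch (as a Python literal, illustrative only) of the JSON shape this executor consumes, reconstructed from the fields the parser above reads; the exact field contents come from the export serde schema, and all names here are placeholders:

example_extern_nodes = {
    "nodes": [
        {
            "name": "buf0",  # ExternKernelNode.name
            "node": {
                # Op name; an overload, if any, is appended after a '.'.
                "target": "aoti_custom_ops::custom_add",
                "inputs": [
                    {"name": "t1", "arg": {"as_tensor": {"name": "arg0_1"}}},
                    {"name": "t2", "arg": {"as_tensor": {"name": "arg1_1"}}},
                ],
                "outputs": [{"as_tensor": {"name": "buf0"}}],
            },
        }
    ]
}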
torch/csrc/inductor/aoti_torch/oss_proxy_executor.h (new file, 101 lines)
@ -0,0 +1,101 @@
#pragma once

#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/ivalue.h>
#include <c10/macros/Export.h>
#include <nlohmann/json.hpp>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_torch/proxy_executor.h>
#include <iostream>

namespace torch::aot_inductor {

enum class DynamicArgType : int {
  TensorType = 0,
  ListTensorType = 1,
  ListOptionalTensorType = 2,
  IntType = 3,
  ListIntType = 4,
};

inline std::ostream& operator<<(std::ostream& os, DynamicArgType arg_type) {
  os << static_cast<int>(arg_type);
  return os;
}

inline bool isTensorType(DynamicArgType arg_type) {
  return arg_type == DynamicArgType::TensorType ||
      arg_type == DynamicArgType::ListTensorType ||
      arg_type == DynamicArgType::ListOptionalTensorType;
}

struct DynamicArg {
  DynamicArg(
      int arg_index,
      DynamicArgType arg_type,
      int length,
      nlohmann::json serialized_arg_val)
      : arg_index(arg_index),
        arg_type(arg_type),
        length(length),
        serialized_arg_val(std::move(serialized_arg_val)) {}
  int arg_index;
  DynamicArgType arg_type;
  int length;
  nlohmann::json serialized_arg_val;
};

struct OpKernel {
  OpKernel(const std::string& target, const c10::OperatorHandle& op_handle)
      : target_(target), op_handle_(op_handle) {}

  std::string target_;
  c10::OperatorHandle op_handle_;
  std::vector<DynamicArg> dynamic_args_;
  std::vector<DynamicArg> outputs_;
  std::vector<c10::IValue> stack_;

  int num_output_tensors() const {
    int num_output_tensors = 0;
    for (const auto& output : outputs_) {
      if (isTensorType(output.arg_type)) {
        num_output_tensors += output.length;
      }
    }
    return num_output_tensors;
  }
};

class OSSProxyExecutor : public ProxyExecutor {
 public:
  explicit OSSProxyExecutor(const std::string& json_path, bool is_cpu);

  void call_function(
      int extern_node_index,
      int num_ints,
      int64_t* flatten_int_args,
      int num_tensors,
      AtenTensorHandle* flatten_tensor_args) override;

 private:
  void prefill_stack_with_static_arguments(
      int index,
      at::TypePtr schema_arg_type,
      const nlohmann::json& thrift_arg,
      OpKernel& op_kernel);

  void get_input_info_from_serialized(
      const std::vector<c10::Argument>& schema_args,
      const nlohmann::json& serialized_node,
      OpKernel& op_kernel);

  void get_output_info_from_serialized(
      const std::vector<c10::Argument>& schema_returns,
      const nlohmann::json& serialized_node,
      OpKernel& op_kernel);

  std::vector<OpKernel> op_kernels_;
  std::unique_ptr<c10::Device> device_;
};

} // namespace torch::aot_inductor
@ -5,6 +5,7 @@
#include <c10/util/Exception.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_torch/mkldnn_tensor.h>
#include <torch/csrc/inductor/aoti_torch/oss_proxy_executor.h>
#include <torch/csrc/inductor/aoti_torch/proxy_executor.h>
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>