[aoti] Add initial custom op support (#127034)

Re-land of https://github.com/pytorch/pytorch/pull/125242

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127034
Approved by: https://github.com/malfet
angelayi
2024-07-24 20:29:54 +00:00
committed by PyTorch MergeBot
parent 44fdf24967
commit b90aa18569
17 changed files with 875 additions and 33 deletions

View File

@ -471,6 +471,7 @@ inductor_core_resources = [
"torch/csrc/inductor/aoti_torch/shim_common.cpp",
"torch/csrc/inductor/aoti_torch/tensor_converter.cpp",
"torch/csrc/inductor/aoti_torch/mkldnn_tensor.cpp",
"torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp",
"torch/csrc/inductor/inductor_ops.cpp",
]

View File

@ -1213,6 +1213,7 @@ if(BUILD_TEST)
)
else()
add_subdirectory(${TORCH_ROOT}/test/cpp/jit ${CMAKE_BINARY_DIR}/test_jit)
add_subdirectory(${TORCH_ROOT}/test/inductor ${CMAKE_BINARY_DIR}/test_inductor)
add_subdirectory(
${TORCH_ROOT}/test/cpp/tensorexpr
${CMAKE_BINARY_DIR}/test_tensorexpr

View File

@ -0,0 +1,8 @@
# Build separate libraries that define custom classes/operators used from our Python tests.
# These are intended to be used with torch.ops.load_library() in our Python test suite.
add_library(aoti_custom_ops SHARED custom_ops.cpp)
target_link_libraries(aoti_custom_ops torch)
if(INSTALL_TEST)
install(TARGETS aoti_custom_ops DESTINATION lib)
endif()
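
A minimal usage sketch (not part of the commit): once this test library is built, a Python test can load it and call the eager kernel directly. The library path below is hypothetical; the test suite resolves it with find_library_location instead.

import torch

# Hypothetical path -- the real tests locate the library via
# find_library_location("libaoti_custom_ops.so") (or the .dll on Windows).
torch.ops.load_library("build/lib/libaoti_custom_ops.so")

x = torch.randn(3, 3)
y = torch.randn(3, 3)
# custom_add is registered by custom_ops.cpp (next file) and simply adds its inputs.
out = torch.ops.aoti_custom_ops.custom_add(x, y)
assert torch.allclose(out, x + y)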

View File

@ -0,0 +1,364 @@
#include <torch/csrc/api/include/torch/types.h>
#include <cstdint>
#include <iostream>
#include <string>
namespace at {
Tensor custom_add_impl(Tensor t1, Tensor t2) {
return t1 + t2;
}
Tensor fn_with_all_inputs_impl(
const Tensor& tensor,
const c10::List<Tensor>& tensors,
const c10::List<std::optional<Tensor>>& optional_tensors,
const bool b8,
const c10::List<bool>& b8s,
const int64_t i64,
const c10::List<int64_t>& i64s,
const int64_t& symint,
const IntArrayRef symints,
const double f64,
const c10::List<double>& f64s,
const at::Scalar& scalar,
at::ArrayRef<at::Scalar> scalars,
const std::string& string,
const std::vector<std::string>& strings,
const c10::ScalarType& dtype,
const MemoryFormat& memory_format,
const Layout& layout,
const Device& device,
// optional
const std::optional<Tensor>& o_tensor,
const std::optional<c10::List<Tensor>>& o_tensors,
const std::optional<bool>& o_b8,
const std::optional<c10::List<bool>>& o_b8s,
const std::optional<int64_t>& o_i64,
const std::optional<c10::List<int64_t>>& o_i64s,
const std::optional<int64_t>& o_symint,
const std::optional<IntArrayRef>& o_symints,
const std::optional<double>& o_f64,
const std::optional<c10::List<double>>& o_f64s,
const std::optional<at::Scalar>& o_scalar,
const std::optional<at::ArrayRef<at::Scalar>>& o_scalars,
const std::optional<std::string>& o_string,
const std::optional<std::vector<std::string>>& o_strings,
const std::optional<c10::ScalarType>& o_dtype,
const std::optional<MemoryFormat>& o_memory_format,
const std::optional<Layout>& o_layout,
const std::optional<Device>& o_device) {
std::cout << "tensor shape: " << tensor.sizes() << std::endl;
std::cout << "tensors shape: ";
for (auto t : tensors) {
std::cout << t.get().toTensor().sizes() << ", ";
}
std::cout << std::endl;
std::cout << "optional tensors shape: ";
for (auto t : optional_tensors) {
if (t.get().toOptional<Tensor>().has_value()) {
std::cout << t.get().toTensor().sizes() << ", ";
} else {
std::cout << "None, ";
}
}
std::cout << std::endl;
std::cout << "b8 " << c10::IValue(b8) << std::endl;
std::cout << "b8s " << c10::IValue(b8s) << std::endl;
std::cout << "i64 " << c10::IValue(i64) << std::endl;
std::cout << "i64s " << c10::IValue(i64s) << std::endl;
std::cout << "symint " << c10::IValue(symint) << std::endl;
std::cout << "symints " << c10::IValue(symints) << std::endl;
std::cout << "f64 " << c10::IValue(f64) << std::endl;
std::cout << "f64s " << c10::IValue(f64s) << std::endl;
std::cout << "scalar " << c10::IValue(scalar) << std::endl;
std::cout << "scalars " << c10::IValue(scalars) << std::endl;
std::cout << "string " << c10::IValue(string) << std::endl;
std::cout << "strings " << c10::IValue(strings) << std::endl;
std::cout << "dtype " << c10::IValue(dtype) << std::endl;
std::cout << "memory_format " << c10::IValue(memory_format) << std::endl;
std::cout << "layout " << c10::IValue(layout) << std::endl;
std::cout << "device " << c10::IValue(device) << std::endl;
std::cout << "o_tensor "
<< (o_tensor.has_value() ? c10::IValue(o_tensor.value().sizes())
: "None")
<< std::endl;
std::cout << "o_tensors shape: ";
if (o_tensors.has_value()) {
for (auto t : o_tensors.value()) {
std::cout << t.get().toTensor().sizes() << ", ";
}
} else {
std::cout << "None";
}
std::cout << std::endl;
std::cout << "o_b8 "
<< (o_b8.has_value() ? c10::IValue(o_b8.value()) : "None")
<< std::endl;
std::cout << "o_b8s "
<< (o_b8s.has_value() ? c10::IValue(o_b8s.value()) : "None")
<< std::endl;
std::cout << "o_i64 "
<< (o_i64.has_value() ? c10::IValue(o_i64.value()) : "None")
<< std::endl;
std::cout << "o_i64s "
<< (o_i64s.has_value() ? c10::IValue(o_i64s.value()) : "None")
<< std::endl;
std::cout << "o_symint "
<< (o_symint.has_value() ? c10::IValue(o_symint.value()) : "None")
<< std::endl;
std::cout << "o_symints "
<< (o_symints.has_value() ? c10::IValue(o_symints.value()) : "None")
<< std::endl;
std::cout << "o_f64 "
<< (o_f64.has_value() ? c10::IValue(o_f64.value()) : "None")
<< std::endl;
std::cout << "o_f64s "
<< (o_f64s.has_value() ? c10::IValue(o_f64s.value()) : "None")
<< std::endl;
std::cout << "o_scalar "
<< (o_scalar.has_value() ? c10::IValue(o_scalar.value()) : "None")
<< std::endl;
std::cout << "o_scalars "
<< (o_scalars.has_value() ? c10::IValue(o_scalars.value()) : "None")
<< std::endl;
std::cout << "o_string "
<< (o_string.has_value() ? c10::IValue(o_string.value()) : "None")
<< std::endl;
std::cout << "o_strings "
<< (o_strings.has_value() ? c10::IValue(o_strings.value()) : "None")
<< std::endl;
std::cout << "o_dtype "
<< (o_dtype.has_value() ? c10::IValue(o_dtype.value()) : "None")
<< std::endl;
std::cout << "o_memory_format "
<< (o_memory_format.has_value()
? c10::IValue(o_memory_format.value())
: "None")
<< std::endl;
std::cout << "o_layout "
<< (o_layout.has_value() ? c10::IValue(o_layout.value()) : "None")
<< std::endl;
std::cout << "o_device "
<< (o_device.has_value() ? c10::IValue(o_device.value()) : "None")
<< std::endl;
int64_t int_hash = 0;
int_hash ^= i64;
for (auto i : i64s) {
int_hash ^= i;
}
if (o_i64.has_value()) {
int_hash ^= o_i64.value();
}
if (o_i64s.has_value()) {
for (auto i : o_i64s.value()) {
int_hash ^= i;
}
}
int_hash ^= symint;
for (auto i : symints) {
int_hash ^= i;
}
if (o_symint.has_value()) {
int_hash ^= o_symint.value();
}
if (o_symints.has_value()) {
for (auto i : o_symints.value()) {
int_hash ^= i;
}
}
return tensor + int_hash;
}
Tensor fn_with_default_input_impl(const Tensor& tensor, const int64_t i64) {
return tensor + i64;
}
std::tuple<Tensor, Tensor> fn_with_tuple_output_impl(
const Tensor& tensor,
const int64_t i64) {
return {tensor + i64, tensor - i64};
}
std::vector<Tensor> fn_with_list_output_impl(
TensorList tensors,
const int64_t i64) {
std::vector<Tensor> outputs;
for (auto& t : tensors) {
outputs.emplace_back(t + i64);
}
return outputs;
}
std::tuple<Tensor, std::vector<Tensor>> fn_with_mix_outputs_impl(
const Tensor& tensor,
TensorList tensors) {
std::vector<Tensor> outputs;
for (auto& t : tensors) {
outputs.emplace_back(t + 2);
}
return {tensor + 1, outputs};
}
std::tuple<Tensor, Tensor> fn_with_input_mutation_impl(
Tensor& t0,
const Tensor& t1,
Tensor& t2) {
t0.add_(1);
t2.sub_(1);
return {t1 + 1, t1 + 2};
}
// NOLINTBEGIN(clang-diagnostic-unused-parameter)
Tensor fn_with_all_inputs_meta(
const Tensor& tensor,
const c10::List<Tensor>& tensors,
const c10::List<std::optional<Tensor>>& optional_tensors,
const bool b8,
const c10::List<bool>& b8s,
const int64_t i64,
const c10::List<int64_t>& i64s,
const c10::SymInt& symint,
c10::SymIntArrayRef symints,
const double f64,
const c10::List<double>& f64s,
const at::Scalar& scalar,
at::ArrayRef<at::Scalar> scalars,
const std::string& string,
const std::vector<std::string>& strings,
const c10::ScalarType& dtype,
const MemoryFormat& memory_format,
const Layout& layout,
const Device& device,
// optional
const std::optional<Tensor>& o_tensor,
const std::optional<c10::List<Tensor>>& o_tensors,
const std::optional<bool>& o_b8,
const std::optional<c10::List<bool>>& o_b8s,
const std::optional<int64_t>& o_i64,
const std::optional<c10::List<int64_t>>& o_i64s,
const std::optional<c10::SymInt>& o_symint,
at::OptionalSymIntArrayRef o_symints,
const std::optional<double>& o_f64,
const std::optional<c10::List<double>>& o_f64s,
const std::optional<at::Scalar>& o_scalar,
const std::optional<at::ArrayRef<at::Scalar>>& o_scalars,
const std::optional<std::string>& o_string,
const std::optional<std::vector<std::string>>& o_strings,
const std::optional<c10::ScalarType>& o_dtype,
const std::optional<MemoryFormat>& o_memory_format,
const std::optional<Layout>& o_layout,
const std::optional<Device>& o_device) {
return tensor;
}
Tensor fn_with_default_input_meta(const Tensor& tensor, const int64_t i64) {
return tensor.clone();
}
std::tuple<Tensor, Tensor> fn_with_tuple_output_meta(
const Tensor& tensor,
const int64_t i64) {
return {tensor.clone(), tensor.clone()};
}
std::vector<Tensor> fn_with_list_output_meta(
TensorList tensors,
const int64_t i64) {
std::vector<Tensor> outputs;
for (auto& t : tensors) {
outputs.push_back(t.clone());
}
return outputs;
}
std::tuple<Tensor, std::vector<Tensor>> fn_with_mix_outputs_meta(
const Tensor& tensor,
TensorList tensors) {
std::vector<Tensor> outputs;
for (auto& t : tensors) {
outputs.push_back(t.clone());
}
return {tensor.clone(), outputs};
}
std::tuple<Tensor, Tensor> fn_with_input_mutation_meta(
Tensor& t0,
const Tensor& t1,
Tensor& t2) {
return {t1.clone(), t1.clone()};
}
} // namespace at
TORCH_LIBRARY(aoti_custom_ops, m) {
m.def("custom_add(Tensor t1, Tensor t2) -> Tensor");
m.def(
"fn_with_all_inputs(Tensor tensor, "
"Tensor[] tensors, "
"Tensor?[] optional_tensors, "
"bool b8, bool[] b8s, "
"int i64, int[] i64s, "
"SymInt symint, SymInt[] symints, "
"float f64, float[] f64s, "
"Scalar scalar, Scalar[] scalars, "
"str string, str[] strings, "
"ScalarType dtype, "
"MemoryFormat memory_format, "
"Layout layout, "
"Device device, "
"*, "
"Tensor? o_tensor, Tensor[]? o_tensors, "
"bool? o_b8, bool[]? o_b8s, "
"int? o_i64, int[]? o_i64s, "
"SymInt? o_symint, SymInt[]? o_symints, "
"float? o_f64, float[]? o_f64s, "
"Scalar? o_scalar, Scalar[]? o_scalars, "
"str? o_string, str[]? o_strings, "
"ScalarType? o_dtype, "
"MemoryFormat? o_memory_format, "
"Layout? o_layout, "
"Device? o_device) -> Tensor");
m.def("fn_with_default_input(Tensor t, int i=3) -> Tensor");
m.def("fn_with_tuple_output(Tensor t, int i) -> (Tensor, Tensor)");
m.def("fn_with_list_output(Tensor[] tensors, int i) -> Tensor[]");
m.def(
"fn_with_mix_outputs(Tensor t, Tensor[] tensors) -> (Tensor, Tensor[])");
m.def(
"fn_with_input_mutation(Tensor(a!) t0, Tensor t1, Tensor(b!) t2) -> (Tensor, Tensor)");
}
TORCH_LIBRARY_IMPL(aoti_custom_ops, CompositeExplicitAutograd, m) {
m.impl("custom_add", at::custom_add_impl);
m.impl("fn_with_all_inputs", at::fn_with_all_inputs_impl);
m.impl("fn_with_default_input", at::fn_with_default_input_impl);
m.impl("fn_with_tuple_output", at::fn_with_tuple_output_impl);
m.impl("fn_with_list_output", at::fn_with_list_output_impl);
m.impl("fn_with_mix_outputs", at::fn_with_mix_outputs_impl);
m.impl("fn_with_input_mutation", at::fn_with_input_mutation_impl);
}
TORCH_LIBRARY_IMPL(aoti_custom_ops, Meta, m) {
m.impl("fn_with_all_inputs", at::fn_with_all_inputs_meta);
m.impl("fn_with_default_input", at::fn_with_default_input_meta);
m.impl("fn_with_tuple_output", at::fn_with_tuple_output_meta);
m.impl("fn_with_list_output", at::fn_with_list_output_meta);
m.impl("fn_with_mix_outputs", at::fn_with_mix_outputs_meta);
m.impl("fn_with_input_mutation", at::fn_with_input_mutation_meta);
}
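
A rough illustration (not part of the commit): after torch.ops.load_library, these ops are exposed under torch.ops.aoti_custom_ops and can be traced with torch.export. The CompositeExplicitAutograd kernels above decompose into ATen ops, and the Meta overloads provide shape propagation for the fn_with_* variants, so tracing with fake tensors never runs a real kernel.

import torch

class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.ops.aoti_custom_ops.custom_add(x, y)

# Assumes libaoti_custom_ops was already loaded as in the CMake snippet above.
ep = torch.export.export(M(), (torch.randn(3, 3), torch.randn(3, 3)))
print(ep.graph)  # custom_add appears as a call_function node on aoti_custom_ops.custom_add.default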

View File

@ -30,9 +30,11 @@ from torch.testing._internal.common_quantization import (
)
from torch.testing._internal.common_utils import (
DeterministicGuard,
find_library_location,
IS_CI,
IS_FBCODE,
IS_MACOS,
IS_SANDCASTLE,
IS_WINDOWS,
skipIfRocm,
TEST_WITH_ROCM,
@ -2469,6 +2471,18 @@ class AOTInductorTestsTemplate:
model.weight += 1
self.check_model(model, example_inputs)
def test_custom_op_add(self) -> None:
class M(torch.nn.Module):
def forward(self, x, y):
return torch.ops.aoti_custom_ops.custom_add(x, y)
m = M().to(device=self.device)
args = (
torch.randn(3, 3, device=self.device),
torch.randn(3, 3, device=self.device),
)
self.check_model(m, args)
def test_triton_kernel_extern_kernel_arg(self):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
@ -3040,7 +3054,21 @@ class AOTInductorTestsTemplate:
common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate)
class AOTInductorTestABICompatibleCpu(TestCase):
class AOTITestCase(TestCase):
def setUp(self):
if IS_SANDCASTLE or IS_FBCODE:
torch.ops.load_library("//caffe2/test/inductor:custom_ops")
elif IS_MACOS:
raise unittest.SkipTest("non-portable load_library call used in test")
else:
lib_file_path = find_library_location("libaoti_custom_ops.so")
if IS_WINDOWS:
lib_file_path = find_library_location("aoti_custom_ops.dll")
torch.ops.load_library(str(lib_file_path))
super().setUp()
class AOTInductorTestABICompatibleCpu(AOTITestCase):
device = "cpu"
abi_compatible = True
check_model = check_model
@ -3199,6 +3227,7 @@ CPU_TEST_FAILURES = {
"test_while_loop_with_outer_buffers": fail_stack_allocation(is_skip=True),
# TODO: use of undeclared identifier 'float8_e4m3fn' and 'half'
"test_fp8": fail_minimal_arrayref_interface(is_skip=True),
"test_custom_op_add": fail_minimal_arrayref_interface(is_skip=True),
}
# test_failures, xfail by default, set is_skip=True to skip
@ -3213,6 +3242,7 @@ CUDA_TEST_FAILURES = {
"test_runtime_checks_shape_failed": fail_non_abi_compatible_cuda(is_skip=True),
# quantized unsupported for GPU
"test_quantized_linear": fail_cuda(is_skip=True),
"test_custom_op_add": fail_non_abi_compatible_cuda(is_skip=True),
}
@ -3268,7 +3298,7 @@ copy_tests(
)
class AOTInductorTestABICompatibleCpuWithStackAllocation(TestCase):
class AOTInductorTestABICompatibleCpuWithStackAllocation(AOTITestCase):
device = "cpu"
abi_compatible = True
check_model = check_model
@ -3306,7 +3336,7 @@ copy_tests(
@unittest.skipIf(sys.platform == "darwin", "No CUDA on MacOS")
class AOTInductorTestABICompatibleCuda(TestCase):
class AOTInductorTestABICompatibleCuda(AOTITestCase):
device = "cuda"
abi_compatible = True
check_model = check_model
@ -3328,7 +3358,7 @@ copy_tests(
IS_FBCODE or sys.platform == "darwin",
"NonABI mode should not be used in fbcode nor on MacOS",
)
class AOTInductorTestNonABICompatibleCpu(TestCase):
class AOTInductorTestNonABICompatibleCpu(AOTITestCase):
device = "cpu"
abi_compatible = False
check_model = check_model
@ -3355,6 +3385,7 @@ copy_tests(
"test_runtime_checks_shape_failed": TestFailure(
("non_abi_compatible_cpu",), is_skip=True
),
"test_custom_op_add": TestFailure(("non_abi_compatible_cpu",), is_skip=True),
},
)
@ -3363,7 +3394,7 @@ copy_tests(
IS_FBCODE or sys.platform == "darwin",
"NonABI mode should not be used in fbcode nor on MacOS",
)
class AOTInductorTestNonABICompatibleCuda(TestCase):
class AOTInductorTestNonABICompatibleCuda(AOTITestCase):
device = "cuda"
abi_compatible = False
check_model = check_model

View File

@ -47,6 +47,15 @@ class AOTIRunnerUtil:
restore_fqn=False,
)
if IS_FBCODE:
from deeplearning.aot_inductor.extern_node_thrift_serializer import (
thrift_serializer,
)
if options is None:
options = {}
options["extern_node_serializer"] = thrift_serializer
with torch.no_grad():
so_path = torch._inductor.aot_compile(gm, example_inputs, options=options) # type: ignore[arg-type]
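
In OSS builds no serializer has to be injected: with this change, GraphLowering (see the graph.py hunk further down) falls back to extern_node_json_serializer, and AotCodeCompiler writes the serialized extern nodes to a .json file next to the compiled .so. A hedged sketch of the plain OSS call:

with torch.no_grad():
    # No extern_node_serializer option is needed outside fbcode; the JSON
    # serializer is used by default, and when the graph contains custom-op
    # (extern) kernels a <so stem>.json sidecar is emitted next to the .so.
    so_path = torch._inductor.aot_compile(gm, example_inputs)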

View File

@ -1,20 +1,18 @@
load("@rules_cc//cc:defs.bzl", "cc_library")
cc_library(
name = "nlohmann",
hdrs = glob(["include/**/*.hpp"]),
includes = [
"/",
],
cc_library(name = "nlohmann",
includes = ["include"],
deps = ["nlohmann-internal"],
visibility = ["//visibility:public"],
)
cc_import(name = "nlohmann-internal",
hdrs = glob(["include/**/*.hpp"]),
visibility = ["//visibility:private"],
)
cc_library(
name = "nlohmann_single_include",
hdrs = glob(["single_include/nlohmann/*.hpp"]),
includes = [
"/",
],
visibility = ["//visibility:public"],
)

View File

@ -0,0 +1,15 @@
from dataclasses import dataclass
from typing import List
from torch._export.serde.schema import Node
@dataclass
class ExternKernelNode:
name: str
node: Node
@dataclass
class ExternKernelNodes:
nodes: List[ExternKernelNode]

View File

@ -1797,7 +1797,7 @@ class AotCodeCompiler:
with lock:
# Currently, this only supports serializing extern nodes in fbcode
# Eventually, we should also have a serializer for OSS.
if config.is_fbcode() and serialized_extern_kernel_nodes:
if serialized_extern_kernel_nodes:
output_json = os.path.splitext(input_path)[0] + ".json"
with open(output_json, "w") as f:
f.write(serialized_extern_kernel_nodes)

View File

@ -2111,19 +2111,19 @@ class CppWrapperCpu(WrapperCodeGen):
if isinstance(output_args, str):
output_args = [output_args]
if config.is_fbcode():
if V.graph.aot_mode and config.abi_compatible:
assert op_overload is not None
assert raw_args is not None
assert outputs is not None
return self.generate_extern_kernel_alloc_and_find_schema_if_needed_fbcode(
return self.generate_extern_kernel_alloc_and_find_schema_if_needed_with_proxy_executor(
cpp_kernel_key,
op_overload,
raw_args,
output_args,
)
else:
return self.generate_extern_kernel_alloc_and_find_schema_if_needed_oss(
return self.generate_extern_kernel_alloc_and_find_schema_if_needed_jit(
buf_name,
python_kernel_name,
cpp_kernel_name,
@ -2207,7 +2207,7 @@ if (custom_op_wrapper.get() == NULL) {
lines += f"PyTuple_SetItem({py_args_var}, {idx}, {generate_py_arg_inner(raw_arg, arg_type)});\n"
return lines
def generate_extern_kernel_alloc_and_find_schema_if_needed_oss(
def generate_extern_kernel_alloc_and_find_schema_if_needed_jit(
self,
buf_name: str,
python_kernel_name: str,
@ -2220,7 +2220,7 @@ if (custom_op_wrapper.get() == NULL) {
raw_args=None,
output_args: Optional[List[str]] = None,
):
if V.graph.aot_mode or not config.abi_compatible:
if not config.abi_compatible:
# Will update this to use an OSS version of ProxyExecutor
if cpp_kernel_key not in self.extern_call_ops:
self.writeline(
@ -2288,7 +2288,7 @@ if (py_{buf_name}.get() == NULL) {{
)
self.writelines(scope_gil_acquire)
def generate_extern_kernel_alloc_and_find_schema_if_needed_fbcode(
def generate_extern_kernel_alloc_and_find_schema_if_needed_with_proxy_executor(
self,
cpp_kernel_key,
op_overload,

View File

@ -0,0 +1,26 @@
import json
from typing import List
from torch._export.serde.aoti_schema import ExternKernelNode, ExternKernelNodes, Node
from torch._export.serde.serialize import _dataclass_to_dict, EnumEncoder
from torch._inductor.ir import ExternKernelNode as inductor_ExternKernelNode
def serialize_extern_kernel_node(
extern_kernel_node: inductor_ExternKernelNode,
) -> ExternKernelNode:
assert isinstance(extern_kernel_node.node, Node)
return ExternKernelNode(
name=extern_kernel_node.name,
node=extern_kernel_node.node,
)
def extern_node_json_serializer(
extern_kernel_nodes: List[inductor_ExternKernelNode],
) -> str:
serialized_nodes = ExternKernelNodes(
nodes=[serialize_extern_kernel_node(node) for node in extern_kernel_nodes]
)
return json.dumps(_dataclass_to_dict(serialized_nodes), cls=EnumEncoder)
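
A hedged illustration (not from the commit) of the JSON shape this serializer emits and that OSSProxyExecutor parses further down; the field values are placeholders and the exact layout follows torch._export.serde.schema.Node:

# Placeholder names throughout -- shown only to make the sidecar format concrete.
example_sidecar = {
    "nodes": [
        {
            "name": "buf0",  # hypothetical inductor buffer name
            "node": {
                "target": "aoti_custom_ops::custom_add",  # qualified op name; overload suffix optional
                "inputs": [
                    {"name": "t1", "arg": {"as_tensor": {"name": "arg0_1"}}},
                    {"name": "t2", "arg": {"as_tensor": {"name": "arg1_1"}}},
                ],
                "outputs": [{"as_tensor": {"name": "buf0"}}],
            },
        }
    ]
}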

View File

@ -380,9 +380,15 @@ class GraphLowering(torch.fx.Interpreter):
self.wrapper_code: WrapperCodeGen = None # type: ignore[assignment]
# See `ProxyExecutor Design Note` in ir.py for more details
self.extern_kernel_nodes: List[ir.ExternKernelNode] = []
self.extern_node_serializer: Optional[
Callable[[List[ir.ExternKernelNode]], Any]
] = extern_node_serializer
from torch._inductor.extern_node_serializer import extern_node_json_serializer
self.extern_node_serializer: Callable[[List[ir.ExternKernelNode]], Any] = (
extern_node_serializer
if config.is_fbcode() and extern_node_serializer
else extern_node_json_serializer
)
self.current_node: torch.fx.Node = None # type: ignore[assignment]
self.lists: Dict[str, List[str]] = {}
self.mutated_inputs: Set[str] = set()
@ -1808,11 +1814,7 @@ class GraphLowering(torch.fx.Interpreter):
output_code_log.debug("Output code: \n%s", code)
serialized_extern_kernel_nodes = None
if (
config.is_fbcode()
and self.extern_kernel_nodes
and self.extern_node_serializer
):
if self.extern_kernel_nodes:
serialized_extern_kernel_nodes = self.extern_node_serializer(
self.extern_kernel_nodes
)

View File

@ -2,8 +2,18 @@
#include <ATen/DynamicLibrary.h>
#include <torch/csrc/inductor/aoti_runner/model_container_runner.h>
#include <torch/csrc/inductor/aoti_torch/oss_proxy_executor.h>
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
// TODO: Investigate why this is necessary; it fixes build problems in FRL
#if __has_include("filesystem")
#include <filesystem>
namespace fs = std::filesystem;
#else
#include <experimental/filesystem>
namespace fs = std::experimental::filesystem;
#endif
namespace torch::inductor {
AOTIModelContainerRunner::AOTIModelContainerRunner(
@ -46,6 +56,17 @@ AOTIModelContainerRunner::AOTIModelContainerRunner(
get_call_spec_func_ = reinterpret_cast<decltype(get_call_spec_func_)>(
model_so_->sym("AOTInductorModelContainerGetCallSpec"));
// Hack to derive the .json file name from the model .so file
size_t lastindex = model_so_path.find_last_of(".");
std::string json_filename = model_so_path.substr(0, lastindex) + ".json";
if (fs::exists(json_filename)) {
proxy_executor_ = std::make_unique<torch::aot_inductor::OSSProxyExecutor>(
json_filename, device_str == "cpu");
proxy_executor_handle_ =
reinterpret_cast<AOTIProxyExecutorHandle>(proxy_executor_.get());
}
AOTI_RUNTIME_ERROR_CODE_CHECK(create_func_(
&container_handle_,
num_models,
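
For reference, a small sketch (an assumption, mirroring the C++ lookup above) of how the sidecar path relates to the compiled library on the Python side:

import os

so_path = "/tmp/model.so"  # hypothetical output of torch._inductor.aot_compile
# Same convention as the runner: swap the extension for ".json".
json_path = os.path.splitext(so_path)[0] + ".json"
if os.path.exists(json_path):
    print("OSSProxyExecutor will be constructed from", json_path)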

View File

@ -3,6 +3,7 @@
#include <ATen/Tensor.h>
#include <torch/csrc/inductor/aoti_runtime/interface.h>
#include <torch/csrc/inductor/aoti_torch/proxy_executor.h>
// Forward declare DynamicLibrary
namespace at {
@ -75,9 +76,10 @@ class TORCH_API AOTIModelContainerRunner {
AOTInductorModelContainerHandle container_handle_ = nullptr;
// TODO: need an OSS proxy executor implementation. For now,
// proxy_executor_handle_ will always be nullptr.
AOTIProxyExecutorHandle proxy_executor_handle_ = nullptr;
AOTIProxyExecutorHandle proxy_executor_handle_;
private:
std::unique_ptr<torch::aot_inductor::ProxyExecutor> proxy_executor_;
};
} // namespace torch::inductor

View File

@ -0,0 +1,262 @@
#include <nlohmann/json.hpp>
#include <fstream>
#include <iostream>
#include <torch/csrc/inductor/aoti_torch/oss_proxy_executor.h>
namespace {
at::Tensor* tensor_handle_to_tensor_pointer(AtenTensorHandle handle) {
return reinterpret_cast<at::Tensor*>(handle);
}
} // namespace
namespace torch::aot_inductor {
void OSSProxyExecutor::prefill_stack_with_static_arguments(
int index,
at::TypePtr schema_arg_type,
const nlohmann::json& serialized_arg,
OpKernel& op_kernel) {
auto& stack = op_kernel.stack_;
auto& dynamic_args = op_kernel.dynamic_args_;
TORCH_CHECK(serialized_arg.size() == 1);
std::string serialized_arg_type = serialized_arg.begin().key();
auto& serialized_arg_val = serialized_arg.begin().value();
switch (schema_arg_type->kind()) {
case c10::TypeKind::TensorType: {
TORCH_CHECK(serialized_arg_type == "as_tensor");
stack.emplace_back();
dynamic_args.emplace_back(
index, DynamicArgType::TensorType, 1, std::move(serialized_arg_val));
break;
}
// TODO: handle the other input types
default:
TORCH_CHECK(false, "Unsupported input type ", serialized_arg_type);
}
}
// Populates op_kernel.stack_, op_kernel.dynamic_args_
void OSSProxyExecutor::get_input_info_from_serialized(
const std::vector<c10::Argument>& schema_args,
const nlohmann::json& serialized_node,
OpKernel& op_kernel) {
int index = 0;
for (const auto& named_argument : serialized_node["inputs"]) {
const auto& arg = named_argument["arg"];
auto& schema_arg = schema_args[index];
prefill_stack_with_static_arguments(
index++, schema_arg.real_type(), arg, op_kernel);
}
// TODO: prefill default values
}
// Populates op_kernel.outputs_
void OSSProxyExecutor::get_output_info_from_serialized(
const std::vector<c10::Argument>& schema_returns,
const nlohmann::json& serialized_node,
OpKernel& op_kernel) {
std::vector<DynamicArg>& outputs = op_kernel.outputs_;
TORCH_CHECK(
schema_returns.size() == serialized_node["outputs"].size(),
"Serialized node doesn't match op's schema outputs.");
size_t output_index = 0;
for (const auto& serialized_output : serialized_node["outputs"]) {
TORCH_CHECK(serialized_output.size() == 1);
std::string serialized_output_type = serialized_output.begin().key();
auto& serialized_output_val = serialized_output.begin().value();
auto& schema_return = schema_returns[output_index];
at::TypePtr schema_return_type = schema_return.real_type();
switch (schema_return_type->kind()) {
case c10::TypeKind::TensorType: {
TORCH_CHECK(
serialized_output_type == "as_tensor",
serialized_node["target"],
" got serialized_output_type of ",
serialized_output_type);
outputs.emplace_back(
output_index,
DynamicArgType::TensorType,
1,
serialized_output_type);
break;
}
case c10::TypeKind::ListType: {
if (schema_return_type->isSubtypeOf(at::ListType::ofTensors())) {
TORCH_CHECK(
serialized_output_type == "as_tensors",
serialized_node["target"],
" got serialized_output_type of ",
serialized_output_type);
outputs.emplace_back(
output_index,
DynamicArgType::ListTensorType,
serialized_output_val.size(),
serialized_output_type);
} else {
TORCH_CHECK(
false,
"Unsupported return list type ",
schema_return_type->repr_str());
}
break;
}
default: {
TORCH_CHECK(
false, "Unsupported return type ", schema_return_type->repr_str());
}
}
output_index++;
}
}
OSSProxyExecutor::OSSProxyExecutor(const std::string& json_path, bool is_cpu) {
if (is_cpu) {
device_ = std::make_unique<c10::Device>(c10::DeviceType::CPU);
} else {
int device_idx = -1;
device_ = std::make_unique<c10::Device>(c10::DeviceType::CUDA, device_idx);
}
std::string extern_kernel_nodes_serialized;
std::ifstream json_file(json_path);
TORCH_CHECK(json_file.is_open());
// Parse file into a json object
nlohmann::json json_obj;
json_file >> json_obj;
// Access data
for (auto const& serialized_extern_node : json_obj["nodes"]) {
auto const& serialized_node = serialized_extern_node["node"];
const std::string& target = serialized_node["target"];
std::string opName;
std::string overloadName;
size_t pos = target.find('.');
if (pos == std::string::npos) {
opName = target;
overloadName = "";
} else {
// There should be no more periods
size_t pos2 = target.find('.', pos + 1);
TORCH_CHECK(pos2 == std::string::npos);
opName = target.substr(0, pos);
overloadName = target.substr(pos + 1, target.length() - pos);
}
c10::OperatorHandle op_handle =
c10::Dispatcher::singleton().findSchemaOrThrow(
opName.c_str(), overloadName.c_str());
const c10::FunctionSchema& schema = op_handle.schema();
const auto& schema_args = schema.arguments();
const auto& schema_returns = schema.returns();
OpKernel op_kernel(target, op_handle);
get_input_info_from_serialized(schema_args, serialized_node, op_kernel);
get_output_info_from_serialized(schema_returns, serialized_node, op_kernel);
op_kernels_.emplace_back(std::move(op_kernel));
}
}
void OSSProxyExecutor::call_function(
int extern_node_index,
int num_ints,
int64_t* flatten_int_args,
int num_tensors,
AtenTensorHandle* flatten_tensor_args) {
TORCH_CHECK(
extern_node_index < static_cast<int>(op_kernels_.size()),
"Invalid extern node index");
OpKernel& op_kernel = op_kernels_[extern_node_index];
std::vector<c10::IValue> stack = op_kernel.stack_;
auto& dynamic_args = op_kernel.dynamic_args_;
int tensor_id = 0;
int int_id = 0;
for (auto& dynamic_arg : dynamic_args) {
int arg_index = dynamic_arg.arg_index;
DynamicArgType dynamic_arg_type = dynamic_arg.arg_type;
int length = dynamic_arg.length;
if (length == 0) {
continue;
}
switch (dynamic_arg_type) {
case DynamicArgType::TensorType: {
at::Tensor* tensor =
tensor_handle_to_tensor_pointer(flatten_tensor_args[tensor_id++]);
stack[arg_index] = *tensor;
break;
}
// TODO: handle other dynamic arg types
default:
TORCH_CHECK(false, "Unsupported dynamic arg type: ", dynamic_arg_type);
}
}
int num_output_tensors = op_kernel.num_output_tensors();
TORCH_CHECK(
tensor_id == num_tensors - num_output_tensors,
"Mismatch between tensors consumed and number of input tensors, got tensor_id = ",
tensor_id,
", expected num = ",
num_tensors - num_output_tensors);
TORCH_CHECK(
int_id == num_ints,
"Mismatch between ints consumed and num_ints, got int_id = ",
int_id,
", num_ints = ",
num_ints);
// Call the op with the prepared stack.
const c10::OperatorHandle& op = op_kernel.op_handle_;
op.callBoxed(stack);
const c10::FunctionSchema& schema = op.schema();
const auto& schema_returns = schema.returns();
TORCH_CHECK(op_kernel.outputs_.size() == stack.size());
// TODO: what about optional outputs? This assert may not hold
TORCH_CHECK(stack.size() == schema_returns.size());
int index = 0;
for (const auto& schema_return : schema_returns) {
if (schema_return.type()->kind() == c10::TypeKind::TensorType) {
at::Tensor* tensor =
tensor_handle_to_tensor_pointer(flatten_tensor_args[tensor_id++]);
*tensor = stack[index++].toTensor();
// TODO: handle tensor list returns
} else {
TORCH_CHECK(
false,
"NYI: Unsupported return type for schema: ",
schema_return.type()->repr_str());
}
}
TORCH_CHECK(
tensor_id == num_tensors,
"Mismatch between tensors consumed and num_tensors, got tensor_id = ",
tensor_id,
", expected num = ",
num_tensors);
}
} // namespace torch::aot_inductor

View File

@ -0,0 +1,101 @@
#pragma once
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/ivalue.h>
#include <c10/macros/Export.h>
#include <nlohmann/json.hpp>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_torch/proxy_executor.h>
#include <iostream>
namespace torch::aot_inductor {
enum class DynamicArgType : int {
TensorType = 0,
ListTensorType = 1,
ListOptionalTensorType = 2,
IntType = 3,
ListIntType = 4,
};
inline std::ostream& operator<<(std::ostream& os, DynamicArgType arg_type) {
os << static_cast<int>(arg_type);
return os;
}
inline bool isTensorType(DynamicArgType arg_type) {
return arg_type == DynamicArgType::TensorType ||
arg_type == DynamicArgType::ListTensorType ||
arg_type == DynamicArgType::ListOptionalTensorType;
}
struct DynamicArg {
DynamicArg(
int arg_index,
DynamicArgType arg_type,
int length,
nlohmann::json serialized_arg_val)
: arg_index(arg_index),
arg_type(arg_type),
length(length),
serialized_arg_val(std::move(serialized_arg_val)) {}
int arg_index;
DynamicArgType arg_type;
int length;
nlohmann::json serialized_arg_val;
};
struct OpKernel {
OpKernel(const std::string& target, const c10::OperatorHandle& op_handle)
: target_(target), op_handle_(op_handle) {}
std::string target_;
c10::OperatorHandle op_handle_;
std::vector<DynamicArg> dynamic_args_;
std::vector<DynamicArg> outputs_;
std::vector<c10::IValue> stack_;
int num_output_tensors() const {
int num_output_tensors = 0;
for (const auto& output : outputs_) {
if (isTensorType(output.arg_type)) {
num_output_tensors += output.length;
}
}
return num_output_tensors;
}
};
class OSSProxyExecutor : public ProxyExecutor {
public:
explicit OSSProxyExecutor(const std::string& json_path, bool is_cpu);
void call_function(
int extern_node_index,
int num_ints,
int64_t* flatten_int_args,
int num_tensors,
AtenTensorHandle* flatten_tensor_args) override;
private:
void prefill_stack_with_static_arguments(
int index,
at::TypePtr schema_arg_type,
const nlohmann::json& thrift_arg,
OpKernel& op_kernel);
void get_input_info_from_serialized(
const std::vector<c10::Argument>& schema_args,
const nlohmann::json& serialized_node,
OpKernel& op_kernel);
void get_output_info_from_serialized(
const std::vector<c10::Argument>& schema_returns,
const nlohmann::json& serialized_node,
OpKernel& op_kernel);
std::vector<OpKernel> op_kernels_;
std::unique_ptr<c10::Device> device_;
};
} // namespace torch::aot_inductor

View File

@ -5,6 +5,7 @@
#include <c10/util/Exception.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_torch/mkldnn_tensor.h>
#include <torch/csrc/inductor/aoti_torch/oss_proxy_executor.h>
#include <torch/csrc/inductor/aoti_torch/proxy_executor.h>
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>