From b90aa185691ec8ac49b254a0c3660c44ed7a141c Mon Sep 17 00:00:00 2001 From: angelayi Date: Wed, 24 Jul 2024 20:29:54 +0000 Subject: [PATCH] [aoti] Add initial custom op support (#127034) Re-land of https://github.com/pytorch/pytorch/pull/125242 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127034 Approved by: https://github.com/malfet --- build_variables.bzl | 1 + caffe2/CMakeLists.txt | 1 + test/inductor/CMakeLists.txt | 8 + test/inductor/custom_ops.cpp | 364 ++++++++++++++++++ test/inductor/test_aot_inductor.py | 41 +- test/inductor/test_aot_inductor_utils.py | 9 + third_party/nlohmann.BUILD | 18 +- torch/_export/serde/aoti_schema.py | 15 + torch/_inductor/codecache.py | 2 +- torch/_inductor/codegen/cpp_wrapper_cpu.py | 12 +- torch/_inductor/extern_node_serializer.py | 26 ++ torch/_inductor/graph.py | 18 +- .../aoti_runner/model_container_runner.cpp | 21 + .../aoti_runner/model_container_runner.h | 8 +- .../aoti_torch/oss_proxy_executor.cpp | 262 +++++++++++++ .../inductor/aoti_torch/oss_proxy_executor.h | 101 +++++ .../csrc/inductor/aoti_torch/shim_common.cpp | 1 + 17 files changed, 875 insertions(+), 33 deletions(-) create mode 100644 test/inductor/CMakeLists.txt create mode 100644 test/inductor/custom_ops.cpp create mode 100644 torch/_export/serde/aoti_schema.py create mode 100644 torch/_inductor/extern_node_serializer.py create mode 100644 torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp create mode 100644 torch/csrc/inductor/aoti_torch/oss_proxy_executor.h diff --git a/build_variables.bzl b/build_variables.bzl index 7c01ad401bb0..49db40e02d0d 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -471,6 +471,7 @@ inductor_core_resources = [ "torch/csrc/inductor/aoti_torch/shim_common.cpp", "torch/csrc/inductor/aoti_torch/tensor_converter.cpp", "torch/csrc/inductor/aoti_torch/mkldnn_tensor.cpp", + "torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp", "torch/csrc/inductor/inductor_ops.cpp", ] diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index cb10de71bf15..511a55828f4a 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1213,6 +1213,7 @@ if(BUILD_TEST) ) else() add_subdirectory(${TORCH_ROOT}/test/cpp/jit ${CMAKE_BINARY_DIR}/test_jit) + add_subdirectory(${TORCH_ROOT}/test/inductor ${CMAKE_BINARY_DIR}/test_inductor) add_subdirectory( ${TORCH_ROOT}/test/cpp/tensorexpr ${CMAKE_BINARY_DIR}/test_tensorexpr diff --git a/test/inductor/CMakeLists.txt b/test/inductor/CMakeLists.txt new file mode 100644 index 000000000000..c27e8c772940 --- /dev/null +++ b/test/inductor/CMakeLists.txt @@ -0,0 +1,8 @@ +# Build separate libraries the define custom classes/operators used from our Python tests. +# These are intended to be used with torch.ops.load_library() in our Python test suite. 
+add_library(aoti_custom_ops SHARED custom_ops.cpp) +target_link_libraries(aoti_custom_ops torch) + +if(INSTALL_TEST) + install(TARGETS aoti_custom_ops DESTINATION lib) +endif() diff --git a/test/inductor/custom_ops.cpp b/test/inductor/custom_ops.cpp new file mode 100644 index 000000000000..da27d79aca7e --- /dev/null +++ b/test/inductor/custom_ops.cpp @@ -0,0 +1,364 @@ +#include + +#include +#include +#include + +namespace at { + +Tensor custom_add_impl(Tensor t1, Tensor t2) { + return t1 + t2; +} + +Tensor fn_with_all_inputs_impl( + const Tensor& tensor, + const c10::List& tensors, + const c10::List>& optional_tensors, + const bool b8, + const c10::List& b8s, + const int64_t i64, + const c10::List& i64s, + const int64_t& symint, + const IntArrayRef symints, + const double f64, + const c10::List& f64s, + const at::Scalar& scalar, + at::ArrayRef scalars, + const std::string& string, + const std::vector& strings, + const c10::ScalarType& dtype, + const MemoryFormat& memory_format, + const Layout& layout, + const Device& device, + // optional + const std::optional& o_tensor, + const std::optional>& o_tensors, + const std::optional& o_b8, + const std::optional>& o_b8s, + const std::optional& o_i64, + const std::optional>& o_i64s, + const std::optional& o_symint, + const std::optional& o_symints, + const std::optional& o_f64, + const std::optional>& o_f64s, + const std::optional& o_scalar, + const std::optional>& o_scalars, + const std::optional& o_string, + const std::optional>& o_strings, + const std::optional& o_dtype, + const std::optional& o_memory_format, + const std::optional& o_layout, + const std::optional& o_device) { + std::cout << "tensor shape: " << tensor.sizes() << std::endl; + + std::cout << "tensors shape: "; + for (auto t : tensors) { + std::cout << t.get().toTensor().sizes() << ", "; + } + std::cout << std::endl; + + std::cout << "optional tensors shape: "; + for (auto t : optional_tensors) { + if (t.get().toOptional().has_value()) { + std::cout << t.get().toTensor().sizes() << ", "; + } else { + std::cout << "None, "; + } + } + std::cout << std::endl; + + std::cout << "b8 " << c10::IValue(b8) << std::endl; + std::cout << "b8s " << c10::IValue(b8s) << std::endl; + std::cout << "i64 " << c10::IValue(i64) << std::endl; + std::cout << "i64s " << c10::IValue(i64s) << std::endl; + std::cout << "symint " << c10::IValue(symint) << std::endl; + std::cout << "symints " << c10::IValue(symints) << std::endl; + std::cout << "f64 " << c10::IValue(f64) << std::endl; + std::cout << "f64s " << c10::IValue(f64s) << std::endl; + std::cout << "scalar " << c10::IValue(scalar) << std::endl; + std::cout << "scalars " << c10::IValue(scalars) << std::endl; + std::cout << "string " << c10::IValue(string) << std::endl; + std::cout << "strings " << c10::IValue(strings) << std::endl; + std::cout << "dtype " << c10::IValue(dtype) << std::endl; + std::cout << "memory_format " << c10::IValue(memory_format) << std::endl; + std::cout << "layout " << c10::IValue(layout) << std::endl; + std::cout << "device " << c10::IValue(device) << std::endl; + + std::cout << "o_tensor " + << (o_tensor.has_value() ? c10::IValue(o_tensor.value().sizes()) + : "None") + << std::endl; + + std::cout << "o_tensors shape: "; + if (o_tensors.has_value()) { + for (auto t : o_tensors.value()) { + std::cout << t.get().toTensor().sizes() << ", "; + } + } else { + std::cout << "None"; + } + std::cout << std::endl; + + std::cout << "o_b8 " + << (o_b8.has_value() ? 
c10::IValue(o_b8.value()) : "None") + << std::endl; + std::cout << "o_b8s " + << (o_b8s.has_value() ? c10::IValue(o_b8s.value()) : "None") + << std::endl; + std::cout << "o_i64 " + << (o_i64.has_value() ? c10::IValue(o_i64.value()) : "None") + << std::endl; + std::cout << "o_i64s " + << (o_i64s.has_value() ? c10::IValue(o_i64s.value()) : "None") + << std::endl; + std::cout << "o_symint " + << (o_symint.has_value() ? c10::IValue(o_symint.value()) : "None") + << std::endl; + std::cout << "o_symints " + << (o_symints.has_value() ? c10::IValue(o_symints.value()) : "None") + << std::endl; + std::cout << "o_f64 " + << (o_f64.has_value() ? c10::IValue(o_f64.value()) : "None") + << std::endl; + std::cout << "o_f64s " + << (o_f64s.has_value() ? c10::IValue(o_f64s.value()) : "None") + << std::endl; + std::cout << "o_scalar " + << (o_scalar.has_value() ? c10::IValue(o_scalar.value()) : "None") + << std::endl; + std::cout << "o_scalars " + << (o_scalars.has_value() ? c10::IValue(o_scalars.value()) : "None") + << std::endl; + std::cout << "o_string " + << (o_string.has_value() ? c10::IValue(o_string.value()) : "None") + << std::endl; + std::cout << "o_strings " + << (o_strings.has_value() ? c10::IValue(o_strings.value()) : "None") + << std::endl; + std::cout << "o_dtype " + << (o_dtype.has_value() ? c10::IValue(o_dtype.value()) : "None") + << std::endl; + std::cout << "o_memory_format " + << (o_memory_format.has_value() + ? c10::IValue(o_memory_format.value()) + : "None") + << std::endl; + std::cout << "o_layout " + << (o_layout.has_value() ? c10::IValue(o_layout.value()) : "None") + << std::endl; + std::cout << "o_device " + << (o_device.has_value() ? c10::IValue(o_device.value()) : "None") + << std::endl; + + int64_t int_hash = 0; + int_hash ^= i64; + for (auto i : i64s) { + int_hash ^= i; + } + if (o_i64.has_value()) { + int_hash ^= o_i64.value(); + } + if (o_i64s.has_value()) { + for (auto i : o_i64s.value()) { + int_hash ^= i; + } + } + + int_hash ^= symint; + for (auto i : symints) { + int_hash ^= i; + } + if (o_symint.has_value()) { + int_hash ^= o_symint.value(); + } + if (o_symints.has_value()) { + for (auto i : o_symints.value()) { + int_hash ^= i; + } + } + + return tensor + int_hash; +} + +Tensor fn_with_default_input_impl(const Tensor& tensor, const int64_t i64) { + return tensor + i64; +} + +std::tuple fn_with_tuple_output_impl( + const Tensor& tensor, + const int64_t i64) { + return {tensor + i64, tensor - i64}; +} + +std::vector fn_with_list_output_impl( + TensorList tensors, + const int64_t i64) { + std::vector outputs; + for (auto& t : tensors) { + outputs.emplace_back(t + i64); + } + return outputs; +} + +std::tuple> fn_with_mix_outputs_impl( + const Tensor& tensor, + TensorList tensors) { + std::vector outputs; + for (auto& t : tensors) { + outputs.emplace_back(t + 2); + } + return {tensor + 1, outputs}; +} + +std::tuple fn_with_input_mutation_impl( + Tensor& t0, + const Tensor& t1, + Tensor& t2) { + t0.add_(1); + t2.sub_(1); + return {t1 + 1, t1 + 2}; +} + +// NOLINTBEGIN(clang-diagnostic-unused-parameter) +Tensor fn_with_all_inputs_meta( + const Tensor& tensor, + const c10::List& tensors, + const c10::List>& optional_tensors, + const bool b8, + const c10::List& b8s, + const int64_t i64, + const c10::List& i64s, + const c10::SymInt& symint, + c10::SymIntArrayRef symints, + const double f64, + const c10::List& f64s, + const at::Scalar& scalar, + at::ArrayRef scalars, + const std::string& string, + const std::vector& strings, + const c10::ScalarType& dtype, + const MemoryFormat& 
memory_format, + const Layout& layout, + const Device& device, + // optional + const std::optional& o_tensor, + const std::optional>& o_tensors, + const std::optional& o_b8, + const std::optional>& o_b8s, + const std::optional& o_i64, + const std::optional>& o_i64s, + const std::optional& o_symint, + at::OptionalSymIntArrayRef o_symints, + const std::optional& o_f64, + const std::optional>& o_f64s, + const std::optional& o_scalar, + const std::optional>& o_scalars, + const std::optional& o_string, + const std::optional>& o_strings, + const std::optional& o_dtype, + const std::optional& o_memory_format, + const std::optional& o_layout, + const std::optional& o_device) { + return tensor; +} + +Tensor fn_with_default_input_meta(const Tensor& tensor, const int64_t i64) { + return tensor.clone(); +} + +std::tuple fn_with_tuple_output_meta( + const Tensor& tensor, + const int64_t i64) { + return {tensor.clone(), tensor.clone()}; +} + +std::vector fn_with_list_output_meta( + TensorList tensors, + const int64_t i64) { + std::vector outputs; + for (auto& t : tensors) { + outputs.push_back(t.clone()); + } + return outputs; +} + +std::tuple> fn_with_mix_outputs_meta( + const Tensor& tensor, + TensorList tensors) { + std::vector outputs; + for (auto& t : tensors) { + outputs.push_back(t.clone()); + } + return {tensor.clone(), outputs}; +} + +std::tuple fn_with_input_mutation_meta( + Tensor& t0, + const Tensor& t1, + Tensor& t2) { + return {t1.clone(), t1.clone()}; +} + +} // namespace at + +TORCH_LIBRARY(aoti_custom_ops, m) { + m.def("custom_add(Tensor t1, Tensor t2) -> Tensor"); + m.def( + "fn_with_all_inputs(Tensor tensor, " + "Tensor[] tensors, " + "Tensor?[] optional_tensors, " + "bool b8, bool[] b8s, " + "int i64, int[] i64s, " + "SymInt symint, SymInt[] symints, " + "float f64, float[] f64s, " + "Scalar scalar, Scalar[] scalars, " + "str string, str[] strings, " + "ScalarType dtype, " + "MemoryFormat memory_format, " + "Layout layout, " + "Device device, " + "*, " + "Tensor? o_tensor, Tensor[]? o_tensors, " + "bool? o_b8, bool[]? o_b8s, " + "int? o_i64, int[]? o_i64s, " + "SymInt? o_symint, SymInt[]? o_symints, " + "float? o_f64, float[]? o_f64s, " + "Scalar? o_scalar, Scalar[]? o_scalars, " + "str? o_string, str[]? o_strings, " + "ScalarType? o_dtype, " + "MemoryFormat? o_memory_format, " + "Layout? o_layout, " + "Device? o_device) -> Tensor"); + + m.def("fn_with_default_input(Tensor t, int i=3) -> Tensor"); + + m.def("fn_with_tuple_output(Tensor t, int i) -> (Tensor, Tensor)"); + + m.def("fn_with_list_output(Tensor[] tensors, int i) -> Tensor[]"); + + m.def( + "fn_with_mix_outputs(Tensor t, Tensor[] tensors) -> (Tensor, Tensor[])"); + + m.def( + "fn_with_input_mutation(Tensor(a!) t0, Tensor t1, Tensor(b!) 
t2) -> (Tensor, Tensor)"); + +} + +TORCH_LIBRARY_IMPL(aoti_custom_ops, CompositeExplicitAutograd, m) { + m.impl("custom_add", at::custom_add_impl); + m.impl("fn_with_all_inputs", at::fn_with_all_inputs_impl); + m.impl("fn_with_default_input", at::fn_with_default_input_impl); + m.impl("fn_with_tuple_output", at::fn_with_tuple_output_impl); + m.impl("fn_with_list_output", at::fn_with_list_output_impl); + m.impl("fn_with_mix_outputs", at::fn_with_mix_outputs_impl); + m.impl("fn_with_input_mutation", at::fn_with_input_mutation_impl); +} + +TORCH_LIBRARY_IMPL(aoti_custom_ops, Meta, m) { + m.impl("fn_with_all_inputs", at::fn_with_all_inputs_meta); + m.impl("fn_with_default_input", at::fn_with_default_input_meta); + m.impl("fn_with_tuple_output", at::fn_with_tuple_output_meta); + m.impl("fn_with_list_output", at::fn_with_list_output_meta); + m.impl("fn_with_mix_outputs", at::fn_with_mix_outputs_meta); + m.impl("fn_with_input_mutation", at::fn_with_input_mutation_meta); +} diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index f32bb5574493..16a0f777368b 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -30,9 +30,11 @@ from torch.testing._internal.common_quantization import ( ) from torch.testing._internal.common_utils import ( DeterministicGuard, + find_library_location, IS_CI, IS_FBCODE, IS_MACOS, + IS_SANDCASTLE, IS_WINDOWS, skipIfRocm, TEST_WITH_ROCM, @@ -2469,6 +2471,18 @@ class AOTInductorTestsTemplate: model.weight += 1 self.check_model(model, example_inputs) + def test_custom_op_add(self) -> None: + class M(torch.nn.Module): + def forward(self, x, y): + return torch.ops.aoti_custom_ops.custom_add(x, y) + + m = M().to(device=self.device) + args = ( + torch.randn(3, 3, device=self.device), + torch.randn(3, 3, device=self.device), + ) + self.check_model(m, args) + def test_triton_kernel_extern_kernel_arg(self): if self.device != "cuda": raise unittest.SkipTest("requires CUDA") @@ -3040,7 +3054,21 @@ class AOTInductorTestsTemplate: common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate) -class AOTInductorTestABICompatibleCpu(TestCase): +class AOTITestCase(TestCase): + def setUp(self): + if IS_SANDCASTLE or IS_FBCODE: + torch.ops.load_library("//caffe2/test/inductor:custom_ops") + elif IS_MACOS: + raise unittest.SkipTest("non-portable load_library call used in test") + else: + lib_file_path = find_library_location("libaoti_custom_ops.so") + if IS_WINDOWS: + lib_file_path = find_library_location("aoti_custom_ops.dll") + torch.ops.load_library(str(lib_file_path)) + super().setUp() + + +class AOTInductorTestABICompatibleCpu(AOTITestCase): device = "cpu" abi_compatible = True check_model = check_model @@ -3199,6 +3227,7 @@ CPU_TEST_FAILURES = { "test_while_loop_with_outer_buffers": fail_stack_allocation(is_skip=True), # TODO: use of undeclared identifier 'float8_e4m3fn' and 'half' "test_fp8": fail_minimal_arrayref_interface(is_skip=True), + "test_custom_op_add": fail_minimal_arrayref_interface(is_skip=True), } # test_failures, xfail by default, set is_skip=True to skip @@ -3213,6 +3242,7 @@ CUDA_TEST_FAILURES = { "test_runtime_checks_shape_failed": fail_non_abi_compatible_cuda(is_skip=True), # quantized unsupported for GPU "test_quantized_linear": fail_cuda(is_skip=True), + "test_custom_op_add": fail_non_abi_compatible_cuda(is_skip=True), } @@ -3268,7 +3298,7 @@ copy_tests( ) -class AOTInductorTestABICompatibleCpuWithStackAllocation(TestCase): +class 
AOTInductorTestABICompatibleCpuWithStackAllocation(AOTITestCase): device = "cpu" abi_compatible = True check_model = check_model @@ -3306,7 +3336,7 @@ copy_tests( @unittest.skipIf(sys.platform == "darwin", "No CUDA on MacOS") -class AOTInductorTestABICompatibleCuda(TestCase): +class AOTInductorTestABICompatibleCuda(AOTITestCase): device = "cuda" abi_compatible = True check_model = check_model @@ -3328,7 +3358,7 @@ copy_tests( IS_FBCODE or sys.platform == "darwin", "NonABI mode should not be used in fbcode nor on MacOS", ) -class AOTInductorTestNonABICompatibleCpu(TestCase): +class AOTInductorTestNonABICompatibleCpu(AOTITestCase): device = "cpu" abi_compatible = False check_model = check_model @@ -3355,6 +3385,7 @@ copy_tests( "test_runtime_checks_shape_failed": TestFailure( ("non_abi_compatible_cpu",), is_skip=True ), + "test_custom_op_add": TestFailure(("non_abi_compatible_cpu",), is_skip=True), }, ) @@ -3363,7 +3394,7 @@ copy_tests( IS_FBCODE or sys.platform == "darwin", "NonABI mode should not be used in fbcode nor on MacOS", ) -class AOTInductorTestNonABICompatibleCuda(TestCase): +class AOTInductorTestNonABICompatibleCuda(AOTITestCase): device = "cuda" abi_compatible = False check_model = check_model diff --git a/test/inductor/test_aot_inductor_utils.py b/test/inductor/test_aot_inductor_utils.py index 22b1778c2f83..75ef8f1112db 100644 --- a/test/inductor/test_aot_inductor_utils.py +++ b/test/inductor/test_aot_inductor_utils.py @@ -47,6 +47,15 @@ class AOTIRunnerUtil: restore_fqn=False, ) + if IS_FBCODE: + from deeplearning.aot_inductor.extern_node_thrift_serializer import ( + thrift_serializer, + ) + + if options is None: + options = {} + options["extern_node_serializer"] = thrift_serializer + with torch.no_grad(): so_path = torch._inductor.aot_compile(gm, example_inputs, options=options) # type: ignore[arg-type] diff --git a/third_party/nlohmann.BUILD b/third_party/nlohmann.BUILD index 58fd987bcbdb..64dfbbab2b6e 100644 --- a/third_party/nlohmann.BUILD +++ b/third_party/nlohmann.BUILD @@ -1,20 +1,18 @@ load("@rules_cc//cc:defs.bzl", "cc_library") -cc_library( - name = "nlohmann", - hdrs = glob(["include/**/*.hpp"]), - includes = [ - "/", - ], +cc_library(name = "nlohmann", + includes = ["include"], + deps = ["nlohmann-internal"], visibility = ["//visibility:public"], ) +cc_import(name = "nlohmann-internal", + hdrs = glob(["include/**/*.hpp"]), + visibility = ["//visibility:private"], +) + cc_library( name = "nlohmann_single_include", hdrs = glob(["single_include/nlohmann/*.hpp"]), - includes = [ - "/", - ], visibility = ["//visibility:public"], ) - diff --git a/torch/_export/serde/aoti_schema.py b/torch/_export/serde/aoti_schema.py new file mode 100644 index 000000000000..17d5ceda0ef0 --- /dev/null +++ b/torch/_export/serde/aoti_schema.py @@ -0,0 +1,15 @@ +from dataclasses import dataclass +from typing import List + +from torch._export.serde.schema import Node + + +@dataclass +class ExternKernelNode: + name: str + node: Node + + +@dataclass +class ExternKernelNodes: + nodes: List[ExternKernelNode] diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 8381b0fdf936..4ea9ec137986 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1797,7 +1797,7 @@ class AotCodeCompiler: with lock: # Currently, this only support serializing extern nodes in fbcode # Eventually, we should also have a serializer for OSS. 
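# With this change the extern-kernel nodes are serialized in OSS builds as well:
# the JSON is written next to the generated artifact with the same stem
# ("<stem>.json"), which is exactly where the OSSProxyExecutor added by this
# patch looks for it when the compiled model is loaded.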
- if config.is_fbcode() and serialized_extern_kernel_nodes: + if serialized_extern_kernel_nodes: output_json = os.path.splitext(input_path)[0] + ".json" with open(output_json, "w") as f: f.write(serialized_extern_kernel_nodes) diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index abd355779f23..8b4f180c5d41 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -2111,19 +2111,19 @@ class CppWrapperCpu(WrapperCodeGen): if isinstance(output_args, str): output_args = [output_args] - if config.is_fbcode(): + if V.graph.aot_mode and config.abi_compatible: assert op_overload is not None assert raw_args is not None assert outputs is not None - return self.generate_extern_kernel_alloc_and_find_schema_if_needed_fbcode( + return self.generate_extern_kernel_alloc_and_find_schema_if_needed_with_proxy_executor( cpp_kernel_key, op_overload, raw_args, output_args, ) else: - return self.generate_extern_kernel_alloc_and_find_schema_if_needed_oss( + return self.generate_extern_kernel_alloc_and_find_schema_if_needed_jit( buf_name, python_kernel_name, cpp_kernel_name, @@ -2207,7 +2207,7 @@ if (custom_op_wrapper.get() == NULL) { lines += f"PyTuple_SetItem({py_args_var}, {idx}, {generate_py_arg_inner(raw_arg, arg_type)});\n" return lines - def generate_extern_kernel_alloc_and_find_schema_if_needed_oss( + def generate_extern_kernel_alloc_and_find_schema_if_needed_jit( self, buf_name: str, python_kernel_name: str, @@ -2220,7 +2220,7 @@ if (custom_op_wrapper.get() == NULL) { raw_args=None, output_args: Optional[List[str]] = None, ): - if V.graph.aot_mode or not config.abi_compatible: + if not config.abi_compatible: # Will update this to use an OSS version ProxyExecutor if cpp_kernel_key not in self.extern_call_ops: self.writeline( @@ -2288,7 +2288,7 @@ if (py_{buf_name}.get() == NULL) {{ ) self.writelines(scope_gil_acquire) - def generate_extern_kernel_alloc_and_find_schema_if_needed_fbcode( + def generate_extern_kernel_alloc_and_find_schema_if_needed_with_proxy_executor( self, cpp_kernel_key, op_overload, diff --git a/torch/_inductor/extern_node_serializer.py b/torch/_inductor/extern_node_serializer.py new file mode 100644 index 000000000000..2aaa8552939b --- /dev/null +++ b/torch/_inductor/extern_node_serializer.py @@ -0,0 +1,26 @@ +import json +from typing import List + +from torch._export.serde.aoti_schema import ExternKernelNode, ExternKernelNodes, Node +from torch._export.serde.serialize import _dataclass_to_dict, EnumEncoder + +from torch._inductor.ir import ExternKernelNode as inductor_ExternKernelNode + + +def serialize_extern_kernel_node( + extern_kernel_node: inductor_ExternKernelNode, +) -> ExternKernelNode: + assert isinstance(extern_kernel_node.node, Node) + return ExternKernelNode( + name=extern_kernel_node.name, + node=extern_kernel_node.node, + ) + + +def extern_node_json_serializer( + extern_kernel_nodes: List[inductor_ExternKernelNode], +) -> str: + serialized_nodes = ExternKernelNodes( + nodes=[serialize_extern_kernel_node(node) for node in extern_kernel_nodes] + ) + return json.dumps(_dataclass_to_dict(serialized_nodes), cls=EnumEncoder) diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 14cb68aba892..1656a2887753 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -380,9 +380,15 @@ class GraphLowering(torch.fx.Interpreter): self.wrapper_code: WrapperCodeGen = None # type: ignore[assignment] # See `ProxyExecutor Design Note` in ir.py for more details 
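# Serializer selection (see the assignment just below): a caller-provided
# serializer (the thrift-based one) is honored only in fbcode; every other
# build falls back to extern_node_json_serializer from
# torch/_inductor/extern_node_serializer.py, so OSS builds always produce a
# serialized extern-node list.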
self.extern_kernel_nodes: List[ir.ExternKernelNode] = [] - self.extern_node_serializer: Optional[ - Callable[[List[ir.ExternKernelNode]], Any] - ] = extern_node_serializer + + from torch._inductor.extern_node_serializer import extern_node_json_serializer + + self.extern_node_serializer: Callable[[List[ir.ExternKernelNode]], Any] = ( + extern_node_serializer + if config.is_fbcode() and extern_node_serializer + else extern_node_json_serializer + ) + self.current_node: torch.fx.Node = None # type: ignore[assignment] self.lists: Dict[str, List[str]] = {} self.mutated_inputs: Set[str] = set() @@ -1808,11 +1814,7 @@ class GraphLowering(torch.fx.Interpreter): output_code_log.debug("Output code: \n%s", code) serialized_extern_kernel_nodes = None - if ( - config.is_fbcode() - and self.extern_kernel_nodes - and self.extern_node_serializer - ): + if self.extern_kernel_nodes: serialized_extern_kernel_nodes = self.extern_node_serializer( self.extern_kernel_nodes ) diff --git a/torch/csrc/inductor/aoti_runner/model_container_runner.cpp b/torch/csrc/inductor/aoti_runner/model_container_runner.cpp index e6b356d85735..8d3067d418d5 100644 --- a/torch/csrc/inductor/aoti_runner/model_container_runner.cpp +++ b/torch/csrc/inductor/aoti_runner/model_container_runner.cpp @@ -2,8 +2,18 @@ #include #include +#include #include +// TODO: Investigate why this is necessary, but fixes build problems in FRL +#if __has_include("filesystem") +#include +namespace fs = std::filesystem; +#else +#include +namespace fs = std::experimental::filesystem; +#endif + namespace torch::inductor { AOTIModelContainerRunner::AOTIModelContainerRunner( @@ -46,6 +56,17 @@ AOTIModelContainerRunner::AOTIModelContainerRunner( get_call_spec_func_ = reinterpret_cast( model_so_->sym("AOTInductorModelContainerGetCallSpec")); + // Hack to find the json file name from the model so file + size_t lastindex = model_so_path.find_last_of("."); + std::string json_filename = model_so_path.substr(0, lastindex) + ".json"; + + if (fs::exists(json_filename)) { + proxy_executor_ = std::make_unique( + json_filename, device_str == "cpu"); + proxy_executor_handle_ = + reinterpret_cast(proxy_executor_.get()); + } + AOTI_RUNTIME_ERROR_CODE_CHECK(create_func_( &container_handle_, num_models, diff --git a/torch/csrc/inductor/aoti_runner/model_container_runner.h b/torch/csrc/inductor/aoti_runner/model_container_runner.h index bf5932068b59..4d4b3fffa702 100644 --- a/torch/csrc/inductor/aoti_runner/model_container_runner.h +++ b/torch/csrc/inductor/aoti_runner/model_container_runner.h @@ -3,6 +3,7 @@ #include #include +#include // Forward declare DynamicLibrary namespace at { @@ -75,9 +76,10 @@ class TORCH_API AOTIModelContainerRunner { AOTInductorModelContainerHandle container_handle_ = nullptr; - // TODO: need an OSS proxy executor implementation. For now, - // proxy_executor_handle_ will always be nullptr. 
- AOTIProxyExecutorHandle proxy_executor_handle_ = nullptr; + AOTIProxyExecutorHandle proxy_executor_handle_; + + private: + std::unique_ptr proxy_executor_; }; } // namespace torch::inductor diff --git a/torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp b/torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp new file mode 100644 index 000000000000..fea21f9d304b --- /dev/null +++ b/torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp @@ -0,0 +1,262 @@ +#include +#include +#include + +#include + +namespace { +at::Tensor* tensor_handle_to_tensor_pointer(AtenTensorHandle handle) { + return reinterpret_cast(handle); +} +} // namespace + +namespace torch::aot_inductor { + +void OSSProxyExecutor::prefill_stack_with_static_arguments( + int index, + at::TypePtr schema_arg_type, + const nlohmann::json& serialized_arg, + OpKernel& op_kernel) { + auto& stack = op_kernel.stack_; + auto& dynamic_args = op_kernel.dynamic_args_; + + TORCH_CHECK(serialized_arg.size() == 1); + std::string serialized_arg_type = serialized_arg.begin().key(); + auto& serialized_arg_val = serialized_arg.begin().value(); + + switch (schema_arg_type->kind()) { + case c10::TypeKind::TensorType: { + TORCH_CHECK(serialized_arg_type == "as_tensor"); + stack.emplace_back(); + dynamic_args.emplace_back( + index, DynamicArgType::TensorType, 1, std::move(serialized_arg_val)); + break; + } + // TODO: handle the other input types + default: + TORCH_CHECK(false, "Unsupported input type ", serialized_arg_type); + } +} + +// Populates op_kernel.stack_, op_kernel.dynamic_args_ +void OSSProxyExecutor::get_input_info_from_serialized( + const std::vector& schema_args, + const nlohmann::json& serialized_node, + OpKernel& op_kernel) { + int index = 0; + for (const auto& named_argument : serialized_node["inputs"]) { + const auto& arg = named_argument["arg"]; + auto& schema_arg = schema_args[index]; + + prefill_stack_with_static_arguments( + index++, schema_arg.real_type(), arg, op_kernel); + } + + // TODO: prefill default values +} + +// Populates op_kernel.outputs_ +void OSSProxyExecutor::get_output_info_from_serialized( + const std::vector& schema_returns, + const nlohmann::json& serialized_node, + OpKernel& op_kernel) { + std::vector& outputs = op_kernel.outputs_; + + TORCH_CHECK( + schema_returns.size() == serialized_node["outputs"].size(), + "Serialized node doesn't match op's schema outputs."); + + size_t output_index = 0; + for (const auto& serialized_output : serialized_node["outputs"]) { + TORCH_CHECK(serialized_output.size() == 1); + std::string serialized_output_type = serialized_output.begin().key(); + auto& serialized_output_val = serialized_output.begin().value(); + + auto& schema_return = schema_returns[output_index]; + at::TypePtr schema_return_type = schema_return.real_type(); + + switch (schema_return_type->kind()) { + case c10::TypeKind::TensorType: { + TORCH_CHECK( + serialized_output_type == "as_tensor", + serialized_node["target"], + " got serialized_output_type of ", + serialized_output_type); + outputs.emplace_back( + output_index, + DynamicArgType::TensorType, + 1, + serialized_output_type); + break; + } + case c10::TypeKind::ListType: { + if (schema_return_type->isSubtypeOf(at::ListType::ofTensors())) { + TORCH_CHECK( + serialized_output_type == "as_tensors", + serialized_node["target"], + " got serialized_output_type of ", + serialized_output_type); + outputs.emplace_back( + output_index, + DynamicArgType::ListTensorType, + serialized_output_val.size(), + serialized_output_type); + } else { + TORCH_CHECK( + false, + 
"Unsupported return list type ", + schema_return_type->repr_str()); + } + break; + } + default: { + TORCH_CHECK( + false, "Unsupported return type ", schema_return_type->repr_str()); + } + } + + output_index++; + } +} + +OSSProxyExecutor::OSSProxyExecutor(const std::string& json_path, bool is_cpu) { + if (is_cpu) { + device_ = std::make_unique(c10::DeviceType::CPU); + } else { + int device_idx = -1; + device_ = std::make_unique(c10::DeviceType::CUDA, device_idx); + } + + std::string extern_kernel_nodes_serialized; + + std::ifstream json_file(json_path); + TORCH_CHECK(json_file.is_open()); + + // Parse file into a json object + nlohmann::json json_obj; + json_file >> json_obj; + + // Access data + for (auto const& serialized_extern_node : json_obj["nodes"]) { + auto const& serialized_node = serialized_extern_node["node"]; + + const std::string& target = serialized_node["target"]; + + std::string opName; + std::string overloadName; + size_t pos = target.find('.'); + if (pos == std::string::npos) { + opName = target; + overloadName = ""; + } else { + // There should be no more periods + size_t pos2 = target.find('.', pos); + TORCH_CHECK(pos2 == std::string::npos); + + opName = target.substr(0, pos); + overloadName = target.substr(pos + 1, target.length() - pos); + } + + c10::OperatorHandle op_handle = + c10::Dispatcher::singleton().findSchemaOrThrow( + opName.c_str(), overloadName.c_str()); + const c10::FunctionSchema& schema = op_handle.schema(); + + const auto& schema_args = schema.arguments(); + const auto& schema_returns = schema.returns(); + + OpKernel op_kernel(target, op_handle); + get_input_info_from_serialized(schema_args, serialized_node, op_kernel); + get_output_info_from_serialized(schema_returns, serialized_node, op_kernel); + + op_kernels_.emplace_back(std::move(op_kernel)); + } +} + +void OSSProxyExecutor::call_function( + int extern_node_index, + int num_ints, + int64_t* flatten_int_args, + int num_tensors, + AtenTensorHandle* flatten_tensor_args) { + TORCH_CHECK( + extern_node_index < static_cast(op_kernels_.size()), + "Invalid extern node index"); + OpKernel& op_kernel = op_kernels_[extern_node_index]; + + std::vector stack = op_kernel.stack_; + auto& dynamic_args = op_kernel.dynamic_args_; + + int tensor_id = 0; + int int_id = 0; + for (auto& dynamic_arg : dynamic_args) { + int arg_index = dynamic_arg.arg_index; + DynamicArgType dynamic_arg_type = dynamic_arg.arg_type; + int length = dynamic_arg.length; + + if (length == 0) { + continue; + } + + switch (dynamic_arg_type) { + case DynamicArgType::TensorType: { + at::Tensor* tensor = + tensor_handle_to_tensor_pointer(flatten_tensor_args[tensor_id++]); + stack[arg_index] = *tensor; + break; + } + // TODO: handle other dynamic arg types + default: + TORCH_CHECK(false, "Unsupported dynamic arg type: ", dynamic_arg_type); + } + } + + int num_output_tensors = op_kernel.num_output_tensors(); + TORCH_CHECK( + tensor_id == num_tensors - num_output_tensors, + "Mismatch between tensors consumed and num of input tensor, got tensor_id = .", + tensor_id, + ", expected num = ", + num_tensors - num_output_tensors); + TORCH_CHECK( + int_id == num_ints, + "Mismatch between ints consumed and num_ints, got int_id = ", + int_id, + ", num_ints = ", + num_ints); + + // Call the op with the prepared stack. 
+ const c10::OperatorHandle& op = op_kernel.op_handle_; + op.callBoxed(stack); + + const c10::FunctionSchema& schema = op.schema(); + const auto& schema_returns = schema.returns(); + + TORCH_CHECK(op_kernel.outputs_.size() == stack.size()); + // TODO: what about optional outputs? This assert may not hold + TORCH_CHECK(stack.size() == schema_returns.size()); + + int index = 0; + for (const auto& schema_return : schema_returns) { + if (schema_return.type()->kind() == c10::TypeKind::TensorType) { + at::Tensor* tensor = + tensor_handle_to_tensor_pointer(flatten_tensor_args[tensor_id++]); + *tensor = stack[index++].toTensor(); + // TODO: handle tensor list returns + } else { + TORCH_CHECK( + false, + "NYI: Unsupported return type for schema: ", + schema_return.type()->repr_str()); + } + } + + TORCH_CHECK( + tensor_id == num_tensors, + "Mismatch between tensors consumed and num_tensors, got tensor_id = ", + tensor_id, + ", expected num = ", + num_tensors); +} + +} // namespace torch::aot_inductor diff --git a/torch/csrc/inductor/aoti_torch/oss_proxy_executor.h b/torch/csrc/inductor/aoti_torch/oss_proxy_executor.h new file mode 100644 index 000000000000..dcdaeba59d38 --- /dev/null +++ b/torch/csrc/inductor/aoti_torch/oss_proxy_executor.h @@ -0,0 +1,101 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace torch::aot_inductor { + +enum class DynamicArgType : int { + TensorType = 0, + ListTensorType = 1, + ListOptionalTensorType = 2, + IntType = 3, + ListIntType = 4, +}; + +inline std::ostream& operator<<(std::ostream& os, DynamicArgType arg_type) { + os << static_cast(arg_type); + return os; +} + +inline bool isTensorType(DynamicArgType arg_type) { + return arg_type == DynamicArgType::TensorType || + arg_type == DynamicArgType::ListTensorType || + arg_type == DynamicArgType::ListOptionalTensorType; +} + +struct DynamicArg { + DynamicArg( + int arg_index, + DynamicArgType arg_type, + int length, + nlohmann::json serialized_arg_val) + : arg_index(arg_index), + arg_type(arg_type), + length(length), + serialized_arg_val(std::move(serialized_arg_val)) {} + int arg_index; + DynamicArgType arg_type; + int length; + nlohmann::json serialized_arg_val; +}; + +struct OpKernel { + OpKernel(const std::string& target, const c10::OperatorHandle& op_handle) + : target_(target), op_handle_(op_handle) {} + + std::string target_; + c10::OperatorHandle op_handle_; + std::vector dynamic_args_; + std::vector outputs_; + std::vector stack_; + + int num_output_tensors() const { + int num_output_tensors = 0; + for (const auto& output : outputs_) { + if (isTensorType(output.arg_type)) { + num_output_tensors += output.length; + } + } + return num_output_tensors; + } +}; + +class OSSProxyExecutor : public ProxyExecutor { + public: + explicit OSSProxyExecutor(const std::string& json_path, bool is_cpu); + + void call_function( + int extern_node_index, + int num_ints, + int64_t* flatten_int_args, + int num_tensors, + AtenTensorHandle* flatten_tensor_args) override; + + private: + void prefill_stack_with_static_arguments( + int index, + at::TypePtr schema_arg_type, + const nlohmann::json& thrift_arg, + OpKernel& op_kernel); + + void get_input_info_from_serialized( + const std::vector& schema_args, + const nlohmann::json& serialized_node, + OpKernel& op_kernel); + + void get_output_info_from_serialized( + const std::vector& schema_returns, + const nlohmann::json& serialized_node, + OpKernel& op_kernel); + + std::vector op_kernels_; + std::unique_ptr device_; +}; + +} // namespace 
torch::aot_inductor
diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp
index 99fb1e6d2bfd..0f82891f1fda 100644
--- a/torch/csrc/inductor/aoti_torch/shim_common.cpp
+++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp
@@ -5,6 +5,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
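
End-to-end usage sketch, assuming a build that contains the patch above: the test
library built from test/inductor/custom_ops.cpp is loaded, a module that calls
torch.ops.aoti_custom_ops.custom_add is AOT-compiled, and the "<stem>.json" written
next to the .so is what OSSProxyExecutor parses when the model container is created.
The library path and the torch._export.aot_compile / torch._export.aot_load entry
points are illustrative assumptions rather than part of the patch; the test suite
exercises the same flow through check_model / AOTIRunnerUtil in
test/inductor/test_aot_inductor.py.

import torch

# Illustrative path; with the CMake rules above the library is built as
# libaoti_custom_ops.so (aoti_custom_ops.dll on Windows).
torch.ops.load_library("build/lib/libaoti_custom_ops.so")

class M(torch.nn.Module):
    def forward(self, x, y):
        # Resolved against the schema registered in TORCH_LIBRARY(aoti_custom_ops, ...).
        return torch.ops.aoti_custom_ops.custom_add(x, y)

example_inputs = (torch.randn(3, 3), torch.randn(3, 3))

with torch.no_grad():
    # Produces model.so plus a model.json sidecar; in OSS builds the sidecar is the
    # output of extern_node_json_serializer and lists the custom-op (extern) calls.
    so_path = torch._export.aot_compile(M(), example_inputs)

# AOTIModelContainerRunner looks for "<so stem>.json"; if it exists, an
# OSSProxyExecutor is constructed and replays the custom-op call through the
# dispatcher at run time.
runner = torch._export.aot_load(so_path, device="cpu")  # assumed available in this build
print(runner(*example_inputs))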