[AOTInductor][Reland] Proxy Executor for Extern Fallback kernels (#107279) (#108350)

Summary:

This is a prototype for running extern fallback kernels with a host-side proxy executor.

Sample of generated cpp wrapper call:
```
        at::Tensor buf0;  // output buffer
        void* tensor_args_var_0[] = {&arg0_1, &arg0_1, &arg1_1, &arg0_1, &arg1_1, &buf0};
        int64_t int_args_var_1[] = {81, 81, 7, 7, 7, 81};
        proxy_executor->call_function("buf0", int_args_var_1, tensor_args_var_0);
```
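
For reference, with the `ProxyExecutor::call_function` signature introduced in this PR (see `proxy_executor.h` in the diff below), the generated call passes the extern kernel node index and the argument counts rather than the output buffer name. Continuing the sample above (node index 0 is assumed for the first extern node in the graph):
```
        at::Tensor buf0;  // output buffer
        void* tensor_args_var_0[] = {&arg0_1, &arg0_1, &arg1_1, &arg0_1, &arg1_1, &buf0};
        int64_t int_args_var_1[] = {81, 81, 7, 7, 7, 81};
        // extern node index, num_ints, int args, num_tensors, tensor args
        proxy_executor->call_function(0, 6, int_args_var_1, 6, tensor_args_var_0);
```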

- In my current implementation, the proxy executor interprets the raw pointers according to the op's schema. This assumes that a custom op MUST have a valid schema registered with the Dispatcher. (I would like to validate this assumption.)
- I am using the callBoxed() API of the custom kernels, as sketched below. This is unavoidable, since we want a single call_function API that works for all possible custom kernels.
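
To make that call path concrete, here is a minimal sketch of how such a schema-driven proxy executor could work, assuming the `call_function` signature from `proxy_executor.h` in this PR. The class name, the `NodeInfo` struct, and the way extern kernel nodes are deserialized are hypothetical, not the fbcode implementation:
```
// Minimal sketch, not the actual fbcode executor. It looks up the op by the
// name recorded in the serialized extern kernel node, rebuilds a boxed stack
// from the flattened int/tensor arrays by walking the registered schema, and
// dispatches through callBoxed().
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/stack.h>
#include <torch/csrc/inductor/proxy_executor.h>

class SchemaProxyExecutor : public torch::aot_inductor::ProxyExecutor {
 public:
  void call_function(
      int extern_node_index,
      int num_ints,
      int64_t* flatten_int_args,
      int num_tensors,
      void** flatten_tensor_args) override {
    // nodes_ would be deserialized from the .json emitted next to the .so.
    const NodeInfo& node = nodes_.at(extern_node_index);
    auto op = c10::Dispatcher::singleton().findSchemaOrThrow(
        node.op_name.c_str(), node.overload_name.c_str());

    torch::jit::Stack stack;
    size_t int_idx = 0;
    size_t tensor_idx = 0;
    for (const auto& arg : op.schema().arguments()) {
      switch (arg.type()->kind()) {
        case c10::TypeKind::TensorType:
          // Tensor arguments arrive as raw at::Tensor* in flatten_tensor_args.
          stack.emplace_back(
              *static_cast<at::Tensor*>(flatten_tensor_args[tensor_idx++]));
          break;
        case c10::TypeKind::IntType:
          // ints / SymInts are flattened into flatten_int_args.
          stack.emplace_back(flatten_int_args[int_idx++]);
          break;
        default:
          // Static arguments (dtype, layout, strings, ...) would be decoded
          // from the serialized node itself; elided in this sketch.
          stack.emplace_back(c10::IValue());
          break;
      }
    }
    TORCH_CHECK(int_idx <= static_cast<size_t>(num_ints));
    TORCH_CHECK(tensor_idx <= static_cast<size_t>(num_tensors));

    // One boxed entry point works for every custom kernel with a schema.
    op.callBoxed(stack);

    // Writing the results back into the preallocated output buffers (the
    // trailing entries of flatten_tensor_args) is elided here.
  }

 private:
  struct NodeInfo {
    std::string op_name;        // e.g. "my_ns::my_op"
    std::string overload_name;  // "" for the default overload
  };
  std::vector<NodeInfo> nodes_;
};
```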

- These are all the input argument types supported so far:
```
union Argument {
  # Bool value does not matter
  1: bool asNone;
  2: TensorArgument asTensor;
  3: list<TensorArgument> asTensors;
  5: i64 asInt;
  7: list<i64> asInts;
  8: double asFloat;
  9: list<double> asFloats;
  10: string asString;
  10.5: list<string> asStrings;
  11: SymIntArgument asSymInt;
  12: list<SymIntArgument> asSymInts;
  13: ScalarType asScalarType;
  14: MemoryFormat asMemoryFormat;
  15: Layout asLayout;
  16: Device asDevice;
  17: bool asBool;
  18: list<bool> asBools;
}
```

- We need a policy for handling unpopulated arguments with default values. Here are the options; the choice has BC implications (a sketch of option 3 follows this list):
1. Require the exported fx graph to explicitly populate default values if the user doesn't specify them.
2. Require the cpp wrapper to explicitly populate default values if the fx graph doesn't specify them.
3. Have the proxy executor look up default values from the op schema.
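
For illustration only, option 3 could be a small helper inside the proxy executor. This is a hypothetical sketch (`resolve_argument` is not an existing function); it relies only on `c10::Argument::default_value()` from the registered schema:
```
// Hypothetical helper for option 3: fall back to the schema default when the
// serialized graph leaves an argument unpopulated.
#include <ATen/core/function_schema.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>

c10::IValue resolve_argument(
    const c10::Argument& schema_arg,
    const c10::optional<c10::IValue>& serialized_value) {
  if (serialized_value.has_value()) {
    return *serialized_value;  // explicitly provided by the exported fx graph
  }
  // Option 3: the proxy executor consults the registered schema for a default.
  TORCH_CHECK(
      schema_arg.default_value().has_value(),
      "argument '", schema_arg.name(), "' has no value and no default");
  return *schema_arg.default_value();
}
```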

For fixing T162112344

Test Plan:
frontend:
buck2 run mode/dev-sand mode/inplace -c fbcode.enable_gpu_sections=True sigmoid/frontend:export_main

test:
 buck2 run mode/dev-sand //deeplearning/aot_inductor/test:test_custom_ops

backend:
buck2 run mode/dev-nosan //deeplearning/aot_inductor/fb:main

buck2 test 'fbcode//mode/opt' fbcode//caffe2/torch/fb/model_transform/experimental/benchmark/test:test_aot_inductor_benchmark -- --exact 'caffe2/torch/fb/model_transform/experimental/benchmark/test:test_aot_inductor_benchmark - test_aot_inductor_benchmark_cmf30x (caffe2.torch.fb.model_transform.experimental.benchmark.test.test_aot_inductor_benchmark.AOTInductorBenchmark)'

Reviewed By: suo

Differential Revision: D48747417

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108350
Approved by: https://github.com/izaitsevfb
Author: Sherlock Huang
Date: 2023-09-02 17:14:10 +00:00
Committed by: PyTorch MergeBot
parent b9fc6d7ded
commit b9dfdc091b
16 changed files with 372 additions and 17 deletions

View File

@ -904,6 +904,7 @@ struct TORCH_API ListType
static ListTypePtr ofComplexDoubles();
static ListTypePtr ofBools();
static ListTypePtr ofStrings();
static ListTypePtr ofNumbers();
private:
ListType(TypePtr elem) : SingleElementType(std::move(elem)) {}

View File

@ -268,6 +268,10 @@ ListTypePtr ListType::ofStrings() {
static auto value = ListType::create(StringType::get());
return value;
}
ListTypePtr ListType::ofNumbers() {
static auto value = ListType::create(NumberType::get());
return value;
}
TypePtr OptionalType::get(TypePtr inner) {
static ska::flat_hash_map<TypePtr, TypePtr> containerTypePtrs;

View File

@ -58,13 +58,17 @@ TEST(AotInductorTest, BasicTest) {
reinterpret_cast<AOTInductorTensorHandle>(inputs.data());
AOTInductorTensorHandle outputs_handle =
reinterpret_cast<AOTInductorTensorHandle>(outputs.data());
AOTInductorProxyExecutorHandle proxy_executor_handle = nullptr;
AOT_INDUCTOR_ERROR_CHECK(AOTInductorModelContainerRun(
container_handle,
inputs_handle,
inputs.size(),
outputs_handle,
outputs.size(),
stream_handle));
stream_handle,
proxy_executor_handle));
ASSERT_TRUE(torch::allclose(results_ref, outputs[0]));
AOT_INDUCTOR_ERROR_CHECK(AOTInductorModelContainerDelete(container_handle));

View File

@ -525,6 +525,8 @@ class GraphModuleSerializer:
)
def serialize_input(self, arg) -> Argument:
import torch._inductor.ir as inductor_ir
if isinstance(arg, torch.fx.Node):
if arg.op == "get_attr":
assert isinstance(arg.target, str)
@ -545,6 +547,13 @@ class GraphModuleSerializer:
return Argument.create(as_sym_bool=SymBoolArgument.create(as_name=arg.name))
else:
return Argument.create(as_tensor=TensorArgument(name=arg.name))
elif isinstance(arg, (inductor_ir.InputBuffer, inductor_ir.ComputedBuffer)):
# Other branches are for arguments in fx node.
# This is a special branch for handling buffers (representing tensor arguments)
# for inductor's ExternalFallbackNode
# export_extern_kernel_node() is using this function to serialize arguments
assert arg.name is not None, "Input buffer must have valid name"
return Argument.create(as_tensor=TensorArgument(name=arg.name))
elif isinstance(arg, bool):
return Argument.create(as_bool=arg)
elif isinstance(arg, str):
@ -594,7 +603,8 @@ class GraphModuleSerializer:
self.graph_state.constants[a.name] = attr
arguments.append(TensorArgument(name=a.name))
return Argument.create(as_tensors=arguments)
elif any(isinstance(a, torch.fx.Node) for a in arg):
elif all(isinstance(a, (torch.fx.Node, type(None))) for a in arg):
# list of optional tensors
def serialize_optional_tensor_args(a):
if a is None:
return OptionalTensorArgument.create(as_none=())
@ -605,6 +615,23 @@ class GraphModuleSerializer:
return Argument.create(
as_optional_tensors=list(map(serialize_optional_tensor_args, arg))
)
elif all(isinstance(a, (inductor_ir.InputBuffer, inductor_ir.ComputedBuffer)) for a in arg):
# list of tensors
return Argument.create(
as_tensors=[TensorArgument(name=a.name) for a in arg],
)
elif all(isinstance(a, (inductor_ir.InputBuffer, inductor_ir.ComputedBuffer, type(None))) for a in arg):
# list of optional tensors
def serialize_optional_tensor_args(a):
if a is None:
return OptionalTensorArgument.create(as_none=())
elif isinstance(a, torch._inductor.ir.InputBuffer):
return OptionalTensorArgument.create(as_tensor=a.name)
else:
raise SerializeError(f"Unsupported list/tuple argument: {a}")
return Argument.create(
as_optional_tensors=list(map(serialize_optional_tensor_args, arg))
)
else:
raise SerializeError(f"Unsupported list/tuple argument type: {type(arg[0])}")
elif isinstance(arg, torch.dtype):

View File

@ -49,7 +49,13 @@ def aot_compile(
gm,
example_inputs,
config_patches=options,
)()
)
# AOTInductor returns result as a string, not callable
# Maybe this check is not needed?
if callable(result):
result = result()
lib_path = result[0] if isinstance(result, (list, tuple)) else result
return lib_path

View File

@ -943,7 +943,7 @@ class AotCodeCache:
clear = staticmethod(cache.clear)
@classmethod
def compile(cls, graph, source_code, cuda):
def compile(cls, graph, source_code, serialized_extern_kernel_nodes, cuda):
# TODO: update cpp_compile_command for different platforms
picked_vec_isa = invalid_vec_isa if cuda else pick_vec_isa()
cpp_command = repr(
@ -963,6 +963,13 @@ class AotCodeCache:
lock_dir = get_lock_dir()
lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
with lock:
# Currently, this only supports serializing extern nodes in fbcode
# Eventually, we should also have a serializer for OSS.
if config.is_fbcode() and serialized_extern_kernel_nodes:
output_json = os.path.splitext(input_path)[0] + ".json"
with open(output_json, "w") as f:
f.write(serialized_extern_kernel_nodes)
output_so = os.path.splitext(input_path)[0] + ".so"
if not os.path.exists(output_so):

View File

@ -2,6 +2,7 @@
#include <torch/csrc/inductor/aot_inductor_model_container.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <iostream>
#include <torch/csrc/inductor/proxy_executor.h>
#include <stdexcept>
#include <vector>
@ -50,7 +51,8 @@ AOTInductorError AOTInductorModelContainerRun(
size_t num_inputs,
AOTInductorTensorHandle outputs_handle,
size_t num_outputs,
AOTInductorStreamHandle stream_handle) {
AOTInductorStreamHandle stream_handle,
AOTInductorProxyExecutorHandle proxy_executor_handle) {
auto* container =
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
container_handle);
@ -70,8 +72,11 @@ AOTInductorError AOTInductorModelContainerRun(
}
auto stream = reinterpret_cast<cudaStream_t>(stream_handle);
torch::aot_inductor::ProxyExecutor* proxy_executor = reinterpret_cast<torch::aot_inductor::ProxyExecutor*>(proxy_executor_handle);
CONVERT_EXCEPTION_TO_ERROR_CODE(
{ container->run(input_tensors, output_tensors, stream); })
{ container->run(input_tensors, output_tensors, stream, proxy_executor); })
}
AOTInductorError AOTInductorModelContainerGetNumInputs(

View File

@ -110,10 +110,16 @@ VECTORIZABLE_RTYPES = {
}
PYTHON_TO_CPP = {
"Tensor": "at::Tensor",
"int": "long",
"float": "double",
"bool": "bool",
"str": "std::string",
"ScalarType": "c10::ScalarType",
"MemoryFormat": "at::MemoryFormat",
"Layout": "at::Layout",
"Device": "at::Device",
"number": "at::Scalar",
}
CONTAINER_PYTHON_TO_CPP = {

View File

@ -472,6 +472,8 @@ class WrapperCodeGen(CodeGen):
cpp_op_schema,
cpp_kernel_key,
cpp_kernel_overload_name="",
op_overload=None,
raw_args=None,
):
self.writeline(f"{name} = {kernel}({', '.join(codegen_args)})")
@ -916,6 +918,7 @@ class CppWrapperCodeGen(WrapperCodeGen):
self.supports_intermediate_hooks = False
self.outputs_need_copy = set()
self.resized_outputs = {}
self.kernel_callsite_id = count()
from .cpp import cexpr
@ -975,7 +978,8 @@ class CppWrapperCodeGen(WrapperCodeGen):
void AOTInductorModel::run_impl(
const std::vector<at::Tensor>& args,
std::vector<at::Tensor>& outputs,
cudaStream_t stream) {
cudaStream_t stream,
ProxyExecutor* proxy_executor) {
"""
)
else:
@ -1298,6 +1302,112 @@ class CppWrapperCodeGen(WrapperCodeGen):
f"{self.codegen_tensor_option(device, dtype)};"
)
def generate_extern_kernel_args_decl_if_needed(
self, op_overload, raw_args, output_args
):
arg_types = [x.real_type for x in op_overload._schema.arguments]
return_types = [x.type for x in op_overload._schema.returns]
new_tensor_args = []
new_int_args = []
def fill_args(arg, arg_type):
static_arg_types = (
torch.FloatType,
torch.BoolType,
torch.StringType,
torch.Type,
torch.DeviceObjType,
)
if isinstance(arg_type, torch.TensorType):
assert isinstance(arg, (ir.InputBuffer, ir.ComputedBuffer))
new_tensor_args.append(f"&{arg.name}")
elif isinstance(arg_type, (torch.IntType, torch.SymIntType)):
# int or SymInt
assert isinstance(arg, int)
new_int_args.append(str(arg))
elif isinstance(arg_type, torch.NumberType):
# Scalar of type int
assert isinstance(arg, (int, float, bool))
# Only treat int Scalar as dynamic
if isinstance(arg, int):
new_int_args.append(str(arg))
elif isinstance(arg_type, torch.ListType):
assert isinstance(arg, (list, tuple))
# List[Tensor]
if isinstance(arg_type.getElementType(), torch.TensorType):
new_tensor_args.extend([f"&{a.name}" for a in arg])
# List[Optional[Tensor]]
elif isinstance(
arg_type.getElementType(), torch.OptionalType
) and isinstance(
arg_type.getElementType().getElementType(), torch.TensorType
):
new_tensor_args.extend([f"&{a.name}" for a in arg if a is not None])
# List [int] or List[SymInt]
elif isinstance(
arg_type.getElementType(), (torch.IntType, torch.SymIntType)
):
new_int_args.extend([str(a) for a in arg])
# List[Scalar]
elif isinstance(arg_type.getElementType(), torch.NumberType):
# Only treat int Scalar as dynamic
is_int_type = [isinstance(a, int) for a in arg]
if any(is_int_type):
assert all(
is_int_type
), "AOTInductor only supports int scalars of the same type"
new_int_args.extend([str(a) for a in arg])
else:
assert isinstance(
arg_type.getElementType(), static_arg_types
), f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"
else:
assert isinstance(
arg_type, static_arg_types
), f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"
for arg, arg_type in zip(raw_args, arg_types):
if arg is not None:
if isinstance(arg_type, torch.OptionalType):
fill_args(arg, arg_type.getElementType())
else:
fill_args(arg, arg_type)
def fill_output_arg(arg, return_type):
if isinstance(return_type, torch.TensorType):
self.writeline(f"at::Tensor {arg}; // output buffer")
new_tensor_args.append(f"&{output_arg}")
elif isinstance(return_type, torch.ListType) and isinstance(
return_type.getElementType(), torch.TensorType
):
# TODO: handle tensor list return type
raise NotImplementedError("NYI support for return type: List[Tensor]")
elif isinstance(return_type, torch.SymIntType):
raise NotImplementedError("NYI support for return type: SymInt")
elif isinstance(return_type, torch.ListType) and isinstance(
return_type.getElementType(), torch.SymIntType
):
raise NotImplementedError("NYI support for return type: List[SymInt]")
else:
raise AssertionError(f"Unsupport return type found: {return_type}")
assert (
len(output_args) == 1
), "Support for multiple returns is not implemented yet"
for output_arg, return_type in zip(output_args, return_types):
# TODO: check schema here
# assume it's a tensor now
if output_arg is not None:
if isinstance(return_type, torch.OptionalType):
fill_output_arg(output_arg, return_type.getElementType())
else:
fill_output_arg(output_arg, return_type)
return new_tensor_args, new_int_args
def generate_extern_kernel_alloc_and_find_schema_if_needed(
self,
name,
@ -1306,6 +1416,37 @@ class CppWrapperCodeGen(WrapperCodeGen):
cpp_op_schema,
cpp_kernel_key,
cpp_kernel_overload_name="",
op_overload=None,
raw_args=None,
):
if config.is_fbcode():
assert op_overload is not None
assert raw_args is not None
return self.generate_extern_kernel_alloc_and_find_schema_if_needed_fbcode(
name,
cpp_kernel_key,
op_overload,
raw_args,
)
else:
return self.generate_extern_kernel_alloc_and_find_schema_if_needed_oss(
name,
kernel,
codegen_args,
cpp_op_schema,
cpp_kernel_key,
cpp_kernel_overload_name,
)
def generate_extern_kernel_alloc_and_find_schema_if_needed_oss(
self,
name,
kernel,
codegen_args,
cpp_op_schema,
cpp_kernel_key,
cpp_kernel_overload_name="",
):
if cpp_kernel_key not in self.extern_call_ops:
self.writeline(
@ -1321,6 +1462,43 @@ class CppWrapperCodeGen(WrapperCodeGen):
f"auto {name} = op_{cpp_kernel_key}.call({', '.join(codegen_args)});"
)
def generate_extern_kernel_alloc_and_find_schema_if_needed_fbcode(
self,
name,
cpp_kernel_key,
op_overload,
raw_args, # contains both args and flatten kwargs
):
output_args = [name]
(
tensor_call_args,
int_call_args,
) = self.generate_extern_kernel_args_decl_if_needed(
op_overload, raw_args, output_args
)
tensor_args_var = f"tensor_args_var_{next(self.kernel_callsite_id)}"
tensor_call_args_str = ", ".join(tensor_call_args)
self.writeline(f"void* {tensor_args_var}[] = {{{tensor_call_args_str}}};")
int_args_var = f"int_args_var_{next(self.kernel_callsite_id)}"
int_call_args_str = ", ".join(int_call_args)
self.writeline(f"int64_t {int_args_var}[] = {{{int_call_args_str}}};")
extern_kernel_node_index = len(V.graph.extern_kernel_nodes) - 1
self.writeline(
f"proxy_executor->call_function("
f"{extern_kernel_node_index}, "
f"{len(int_call_args)}, "
f"{int_args_var}, "
f"{len(tensor_call_args)}, "
f"{tensor_args_var});"
)
self.extern_call_ops.add(cpp_kernel_key)
def val_to_arg_str(self, val):
from .cpp import DTYPE_TO_ATEN
@ -1353,7 +1531,6 @@ class CudaWrapperCodeGen(CppWrapperCodeGen):
def __init__(self):
super().__init__()
self.kernel_callsite_id = count()
self.arg_var_id = count()
self.cuda = True

View File

@ -39,6 +39,7 @@ from .fx_passes.joint_graph import joint_graph_passes
from .fx_passes.post_grad import post_grad_passes, view_to_reshape
from .fx_passes.pre_grad import pre_grad_passes
from .graph import GraphLowering
from .ir import ExternKernelNode
from .pattern_matcher import clone_graph
from .utils import get_dtype_size, has_incompatible_cudagraph_ops
from .virtualized import V
@ -299,6 +300,7 @@ def compile_fx_inner(
boxed_forward_device_index: Optional[BoxedDeviceIndex] = None,
user_visible_outputs: FrozenSet[str] = frozenset(),
layout_opt: Optional[bool] = None,
extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None,
):
"""
Inductor API that compiles a single graph.
@ -340,6 +342,7 @@ def compile_fx_inner(
"is_inference": is_inference,
"user_visible_outputs": user_visible_outputs,
"layout_opt": layout_opt,
"extern_node_serializer": extern_node_serializer,
}
compiled_graph: CompiledFxGraph = fx_codegen_and_compile(
@ -493,6 +496,7 @@ def fx_codegen_and_compile(
is_inference: bool = False,
user_visible_outputs: FrozenSet[str] = frozenset(),
layout_opt: Optional[bool] = None,
extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None,
) -> CompiledFxGraph:
if is_tf32_warning_applicable(gm):
_warn_tf32_disabled()
@ -549,6 +553,7 @@ def fx_codegen_and_compile(
cpp_wrapper=cpp_wrapper,
aot_mode=aot_mode,
user_visible_outputs=user_visible_outputs,
extern_node_serializer=extern_node_serializer,
)
with V.set_graph_handler(graph): # type: ignore[call-arg]
graph.run(*example_inputs)
@ -864,11 +869,16 @@ def compile_fx_aot(
"aot_inductor_output_path": code_hash(model_.code),
}
extern_node_serializer = config_patches.pop("extern_node_serializer", None)
with mock.patch.object(_in_aot_compilation, "value", True):
return compile_fx(
model_,
example_inputs_,
inner_compile=functools.partial(inner_compile, aot_mode=True),
inner_compile=functools.partial(
inner_compile,
aot_mode=True,
extern_node_serializer=extern_node_serializer,
),
config_patches=config_patches,
)

View File

@ -7,7 +7,7 @@ import sys
import time
from collections import defaultdict
from contextlib import contextmanager
from typing import DefaultDict, Dict, List, Optional, Set, Tuple
from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set, Tuple
import sympy
@ -165,6 +165,7 @@ class GraphLowering(torch.fx.Interpreter):
aot_mode=False,
user_visible_outputs=frozenset(),
layout_opt=None,
extern_node_serializer=None,
):
super().__init__(gm)
@ -196,6 +197,11 @@ class GraphLowering(torch.fx.Interpreter):
self.mutated_buffers: Set[str] = set()
self.inplaced_to_remove: Set[str] = set()
self.wrapper_code: Optional[WrapperCodeGen] = None
# See `ProxyExecutor Design Note` in ir.py for more details
self.extern_kernel_nodes: List[ir.ExternKernelNode] = []
self.extern_node_serializer: Optional[
Callable[[List[ir.ExternKernelNode]], Any]
] = extern_node_serializer
self.current_node: Optional[torch.fx.Node] = None
self.num_static_inputs = num_static_inputs
self.lists: Dict[str, List[str]] = {}
@ -963,8 +969,24 @@ class GraphLowering(torch.fx.Interpreter):
code, linemap = self.codegen()
output_code_log.debug("Output code: \n%s", code)
serialized_extern_kernel_nodes = None
if (
config.is_fbcode()
and self.extern_kernel_nodes
and self.extern_node_serializer
):
serialized_extern_kernel_nodes = self.extern_node_serializer(
self.extern_kernel_nodes
)
output_code_log.debug(
"Serialized Extern Kernel Nodes: \n%s",
serialized_extern_kernel_nodes,
)
# Directly return the file path with the compiled code
return AotCodeCache.compile(self, code, cuda=self.cuda)
return AotCodeCache.compile(
self, code, serialized_extern_kernel_nodes, cuda=self.cuda
)
else:
return self.compile_to_module().call

View File

@ -28,11 +28,14 @@ from unittest.mock import patch
import sympy
from sympy import Expr, Integer
import torch._export.serde.schema as export_schema
import torch._logging
import torch.fx
import torch.utils._pytree as pytree
from torch._dynamo.utils import identity
from torch._export.serde.serialize import GraphModuleSerializer
from torch._prims_common import (
compute_required_storage_length,
is_boolean_dtype,
@ -3566,6 +3569,12 @@ class DynamicScalar(IRNode):
return ()
@dataclasses.dataclass
class ExternKernelNode:
name: str
node: export_schema.Node
@dataclasses.dataclass
class FallbackKernel(ExternKernelAlloc):
def __init__(
@ -3585,6 +3594,8 @@ class FallbackKernel(ExternKernelAlloc):
)
self.use_cpp_op_schema = False
self.op_overload = kernel
op_overload_packet = (
kernel._overloadpacket
if isinstance(kernel, torch._ops.OpOverload)
@ -3692,8 +3703,48 @@ class FallbackKernel(ExternKernelAlloc):
return devices[0]
return None
# ProxyExecutor Design Note
# We export the ExternFallbackNodes (for custom ops) into a serialized file
# and run it with a host side proxy executor to address the ABI problem
# This is currently only implemented for fbcode. Eventually, we will also make this work for OSS.
# Detailed design doc can be found at
# https://docs.google.com/document/d/1wC4DOZFaYym2t1Esz0X5yxlLI3RDnSiyRbUus3bkJ64/edit?usp=sharing
def export_extern_kernel_node(self):
args, kwargs = self.unflatten_args(self.inputs, self.constant_args)
ordered_kwargs = [
kwargs.get(key, None) for key in self.ordered_kwargs_for_cpp_kernel
]
serializer = GraphModuleSerializer(None, None, None)
named_arguments = serializer.serialize_inputs(self.op_overload, args, kwargs)
# TODO: only single output is supported
output_arguments = [
export_schema.Argument.create(
as_tensor=export_schema.TensorArgument(name=self.get_name())
)
]
node = ExternKernelNode(
name=self.get_name(),
node=export_schema.Node(
target=self.kernel,
inputs=named_arguments,
outputs=output_arguments,
metadata={},
),
)
V.graph.extern_kernel_nodes.append(node)
return [*args, *ordered_kwargs]
def codegen(self, wrapper):
if self.use_cpp_op_schema:
exported_args = None
if config.is_fbcode():
exported_args = self.export_extern_kernel_node()
args = [*self.codegen_args(), *self.codegen_kwargs()]
wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
self.get_name(),
@ -3702,6 +3753,8 @@ class FallbackKernel(ExternKernelAlloc):
self.cpp_op_schema,
self.cpp_kernel_key,
self.cpp_kernel_overlad_name,
self.op_overload,
exported_args,
)
else:
super().codegen(wrapper)

View File

@ -44,6 +44,9 @@ using AOTInductorStreamHandle = AOTInductorStreamOpaque*;
struct AOTInductorTensorOpaque {};
using AOTInductorTensorHandle = AOTInductorTensorOpaque*;
struct AOTInductorProxyExecutorOpaque {};
using AOTInductorProxyExecutorHandle = AOTInductorProxyExecutorOpaque*;
extern "C" {
// Creates an AOTInductor model container. The parameter num_models
// specifies the number of model instances that may be run concurrently for
@ -63,7 +66,8 @@ AOTInductorError AOTInductorModelContainerRun(
size_t num_inputs,
AOTInductorTensorHandle outputs_handle,
size_t num_outputs,
AOTInductorStreamHandle stream_handle);
AOTInductorStreamHandle stream_handle,
AOTInductorProxyExecutorHandle proxy_executor_handle);
// Retrieves the number of inputs for the model.
AOTInductorError AOTInductorModelContainerGetNumInputs(

View File

@ -7,6 +7,7 @@
#include <ATen/ATen.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/csrc/inductor/proxy_executor.h>
#define AOT_VECTOR_SIZE_CHECK(vec, expected_size) \
{ \
@ -51,12 +52,13 @@ class AOTInductorModelBase {
void run(
const std::vector<at::Tensor>& inputs,
std::vector<at::Tensor>& outputs,
cudaStream_t stream) {
cudaStream_t stream,
ProxyExecutor* proxy_executor = nullptr) {
AOT_VECTOR_SIZE_CHECK(inputs, num_inputs());
AOT_VECTOR_SIZE_CHECK(outputs, num_outputs());
auto* model = static_cast<Model*>(this);
model->run_impl(inputs, outputs, stream);
model->run_impl(inputs, outputs, stream, proxy_executor);
C10_CUDA_CHECK(cudaEventRecord(run_finished_, stream));
}
@ -178,7 +180,8 @@ class AOTInductorModel : public AOTInductorModelBase<AOTInductorModel> {
void run_impl(
const std::vector<at::Tensor>& inputs,
std::vector<at::Tensor>& outputs,
cudaStream_t stream);
cudaStream_t stream,
ProxyExecutor* proxy_executor = nullptr);
static std::unique_ptr<AOTInductorModel> Create() {
return std::make_unique<AOTInductorModel>();

View File

@ -5,6 +5,7 @@
#include <shared_mutex>
#include <torch/csrc/inductor/aot_inductor_model.h>
#include <torch/csrc/inductor/proxy_executor.h>
namespace torch {
namespace aot_inductor {
@ -56,12 +57,13 @@ class AOTInductorModelContainer {
void run(
const std::vector<at::Tensor>& inputs,
std::vector<at::Tensor>& outputs,
cudaStream_t stream) {
cudaStream_t stream,
ProxyExecutor* proxy_executor) {
auto* model = get_available_model();
try {
AOT_VECTOR_SIZE_CHECK(inputs, num_inputs());
AOT_VECTOR_SIZE_CHECK(outputs, num_outputs());
model->run(inputs, outputs, stream);
model->run(inputs, outputs, stream, proxy_executor);
} catch (...) {
std::lock_guard lk(models_mutex_);
available_models_.push_back(model);

View File

@ -0,0 +1,24 @@
#pragma once
#include <ATen/core/ivalue.h>
#include <c10/macros/Export.h>
#include <string>
namespace torch {
namespace aot_inductor {
class TORCH_API ProxyExecutor : public torch::CustomClassHolder {
public:
ProxyExecutor() {}
virtual ~ProxyExecutor() {}
virtual void call_function(
int extern_node_index,
int num_ints,
int64_t* flatten_int_args,
int num_tensors,
void** flatten_tensor_args) = 0;
};
} // namespace aot_inductor
} // namespace torch