mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: This is a prototype for running extern fallback kernels with a host side proxy executor. Sample of generated cpp wrapper call: ``` at::Tensor buf0; // output buffer void* tensor_args_var_0[] = {&arg0_1, &arg0_1, &arg1_1, &arg0_1, &arg1_1, &buf0}; int64_t int_args_var_1[] = {81, 81, 7, 7, 7, 81}; proxy_executor->call_function("buf0", int_args_var_1, tensor_args_var_0); ``` - In my current implementation, proxy executor interprets the raw pointers according to the ops schema. This assumes that custom op MUST have a valid schema registered to Dispatcher. (I would like to validate this assumption) - I am using callboxed() API of the custom kernels. This is inevitable, as we wish to have a single call_function API for all possible custom kernels. - These are all the input argument types I have support so far. union Argument { # Bool value does not matter 1: bool asNone; 2: TensorArgument asTensor; 3: list<TensorArgument> asTensors; 5: i64 asInt; 7: list<i64> asInts; 8: double asFloat; 9: list<double> asFloats; 10: string asString; 10.5: list<string> asStrings; 11: SymIntArgument asSymInt; 12: list<SymIntArgument> asSymInts; 13: ScalarType asScalarType; 14: MemoryFormat asMemoryFormat; 15: Layout asLayout; 16: Device asDevice; 17: bool asBool; 18: list<bool> asBools; } - Need a policy for handling unpopulated argument with default values. Here are the options, and it has BC implications. 1. requires exported fx graph to explicitly populate default values, if users doesn't specify. 2. requires cpp wrapper to explicitly populate default values, if fx graph doesn't specify. 3. Proxy executor look up from opSchema for default values. For fixing T162112344 Test Plan: frontend: buck2 run mode/dev-sand mode/inplace -c fbcode.enable_gpu_sections=True sigmoid/frontend:export_main test: buck2 run mode/dev-sand //deeplearning/aot_inductor/test:test_custom_ops backend: buck2 run mode/dev-nosan //deeplearning/aot_inductor/fb:main buck2 test 'fbcode//mode/opt' fbcode//caffe2/torch/fb/model_transform/experimental/benchmark/test:test_aot_inductor_benchmark -- --exact 'caffe2/torch/fb/model_transform/experimental/benchmark/test:test_aot_inductor_benchmark - test_aot_inductor_benchmark_cmf30x (caffe2.torch.fb.model_transform.experimental.benchmark.test.test_aot_inductor_benchmark.AOTInductorBenchmark)' Reviewed By: suo Differential Revision: D48747417 Pull Request resolved: https://github.com/pytorch/pytorch/pull/108350 Approved by: https://github.com/izaitsevfb
This commit is contained in:
committed by
PyTorch MergeBot
parent
b9fc6d7ded
commit
b9dfdc091b
@ -904,6 +904,7 @@ struct TORCH_API ListType
|
||||
static ListTypePtr ofComplexDoubles();
|
||||
static ListTypePtr ofBools();
|
||||
static ListTypePtr ofStrings();
|
||||
static ListTypePtr ofNumbers();
|
||||
|
||||
private:
|
||||
ListType(TypePtr elem) : SingleElementType(std::move(elem)) {}
|
||||
|
@ -268,6 +268,10 @@ ListTypePtr ListType::ofStrings() {
|
||||
static auto value = ListType::create(StringType::get());
|
||||
return value;
|
||||
}
|
||||
ListTypePtr ListType::ofNumbers() {
|
||||
static auto value = ListType::create(NumberType::get());
|
||||
return value;
|
||||
}
|
||||
|
||||
TypePtr OptionalType::get(TypePtr inner) {
|
||||
static ska::flat_hash_map<TypePtr, TypePtr> containerTypePtrs;
|
||||
|
@ -58,13 +58,17 @@ TEST(AotInductorTest, BasicTest) {
|
||||
reinterpret_cast<AOTInductorTensorHandle>(inputs.data());
|
||||
AOTInductorTensorHandle outputs_handle =
|
||||
reinterpret_cast<AOTInductorTensorHandle>(outputs.data());
|
||||
|
||||
AOTInductorProxyExecutorHandle proxy_executor_handle = nullptr;
|
||||
|
||||
AOT_INDUCTOR_ERROR_CHECK(AOTInductorModelContainerRun(
|
||||
container_handle,
|
||||
inputs_handle,
|
||||
inputs.size(),
|
||||
outputs_handle,
|
||||
outputs.size(),
|
||||
stream_handle));
|
||||
stream_handle,
|
||||
proxy_executor_handle));
|
||||
|
||||
ASSERT_TRUE(torch::allclose(results_ref, outputs[0]));
|
||||
AOT_INDUCTOR_ERROR_CHECK(AOTInductorModelContainerDelete(container_handle));
|
||||
|
@ -525,6 +525,8 @@ class GraphModuleSerializer:
|
||||
)
|
||||
|
||||
def serialize_input(self, arg) -> Argument:
|
||||
import torch._inductor.ir as inductor_ir
|
||||
|
||||
if isinstance(arg, torch.fx.Node):
|
||||
if arg.op == "get_attr":
|
||||
assert isinstance(arg.target, str)
|
||||
@ -545,6 +547,13 @@ class GraphModuleSerializer:
|
||||
return Argument.create(as_sym_bool=SymBoolArgument.create(as_name=arg.name))
|
||||
else:
|
||||
return Argument.create(as_tensor=TensorArgument(name=arg.name))
|
||||
elif isinstance(arg, (inductor_ir.InputBuffer, inductor_ir.ComputedBuffer)):
|
||||
# Other branches are for arguments in fx node.
|
||||
# This is a special branch for handling buffers (representing tensor arguments)
|
||||
# for inductor's ExternalFallbackNode
|
||||
# export_extern_kernel_node() is using this function to serialize arguments
|
||||
assert arg.name is not None, "Input buffer must have valid name"
|
||||
return Argument.create(as_tensor=TensorArgument(name=arg.name))
|
||||
elif isinstance(arg, bool):
|
||||
return Argument.create(as_bool=arg)
|
||||
elif isinstance(arg, str):
|
||||
@ -594,7 +603,8 @@ class GraphModuleSerializer:
|
||||
self.graph_state.constants[a.name] = attr
|
||||
arguments.append(TensorArgument(name=a.name))
|
||||
return Argument.create(as_tensors=arguments)
|
||||
elif any(isinstance(a, torch.fx.Node) for a in arg):
|
||||
elif all(isinstance(a, (torch.fx.Node, type(None))) for a in arg):
|
||||
# list of optional tensors
|
||||
def serialize_optional_tensor_args(a):
|
||||
if a is None:
|
||||
return OptionalTensorArgument.create(as_none=())
|
||||
@ -605,6 +615,23 @@ class GraphModuleSerializer:
|
||||
return Argument.create(
|
||||
as_optional_tensors=list(map(serialize_optional_tensor_args, arg))
|
||||
)
|
||||
elif all(isinstance(a, (inductor_ir.InputBuffer, inductor_ir.ComputedBuffer)) for a in arg):
|
||||
# list of tensors
|
||||
return Argument.create(
|
||||
as_tensors=[TensorArgument(name=a.name) for a in arg],
|
||||
)
|
||||
elif all(isinstance(a, (inductor_ir.InputBuffer, inductor_ir.ComputedBuffer, type(None))) for a in arg):
|
||||
# list of optional tensors
|
||||
def serialize_optional_tensor_args(a):
|
||||
if a is None:
|
||||
return OptionalTensorArgument.create(as_none=())
|
||||
elif isinstance(a, torch._inductor.ir.InputBuffer):
|
||||
return OptionalTensorArgument.create(as_tensor=a.name)
|
||||
else:
|
||||
raise SerializeError(f"Unsupported list/tuple argument: {a}")
|
||||
return Argument.create(
|
||||
as_optional_tensors=list(map(serialize_optional_tensor_args, arg))
|
||||
)
|
||||
else:
|
||||
raise SerializeError(f"Unsupported list/tuple argument type: {type(arg[0])}")
|
||||
elif isinstance(arg, torch.dtype):
|
||||
|
@ -49,7 +49,13 @@ def aot_compile(
|
||||
gm,
|
||||
example_inputs,
|
||||
config_patches=options,
|
||||
)()
|
||||
)
|
||||
|
||||
# AOTInductor returns result as a string, not callable
|
||||
# Maybe this check is not neded?
|
||||
if callable(result):
|
||||
result = result()
|
||||
|
||||
lib_path = result[0] if isinstance(result, (list, tuple)) else result
|
||||
return lib_path
|
||||
|
||||
|
@ -943,7 +943,7 @@ class AotCodeCache:
|
||||
clear = staticmethod(cache.clear)
|
||||
|
||||
@classmethod
|
||||
def compile(cls, graph, source_code, cuda):
|
||||
def compile(cls, graph, source_code, serialized_extern_kernel_nodes, cuda):
|
||||
# TODO: update cpp_compile_command for different platforms
|
||||
picked_vec_isa = invalid_vec_isa if cuda else pick_vec_isa()
|
||||
cpp_command = repr(
|
||||
@ -963,6 +963,13 @@ class AotCodeCache:
|
||||
lock_dir = get_lock_dir()
|
||||
lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
|
||||
with lock:
|
||||
# Currently, this only support serializing extern nodes in fbcode
|
||||
# Eventually, we should also have a serializer for OSS.
|
||||
if config.is_fbcode() and serialized_extern_kernel_nodes:
|
||||
output_json = os.path.splitext(input_path)[0] + ".json"
|
||||
with open(output_json, "w") as f:
|
||||
f.write(serialized_extern_kernel_nodes)
|
||||
|
||||
output_so = os.path.splitext(input_path)[0] + ".so"
|
||||
|
||||
if not os.path.exists(output_so):
|
||||
|
@ -2,6 +2,7 @@
|
||||
#include <torch/csrc/inductor/aot_inductor_model_container.h>
|
||||
#include <ATen/core/dispatch/Dispatcher.h>
|
||||
#include <iostream>
|
||||
#include <torch/csrc/inductor/proxy_executor.h>
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
|
||||
@ -50,7 +51,8 @@ AOTInductorError AOTInductorModelContainerRun(
|
||||
size_t num_inputs,
|
||||
AOTInductorTensorHandle outputs_handle,
|
||||
size_t num_outputs,
|
||||
AOTInductorStreamHandle stream_handle) {
|
||||
AOTInductorStreamHandle stream_handle,
|
||||
AOTInductorProxyExecutorHandle proxy_executor_handle) {
|
||||
auto* container =
|
||||
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
|
||||
container_handle);
|
||||
@ -70,8 +72,11 @@ AOTInductorError AOTInductorModelContainerRun(
|
||||
}
|
||||
|
||||
auto stream = reinterpret_cast<cudaStream_t>(stream_handle);
|
||||
|
||||
torch::aot_inductor::ProxyExecutor* proxy_executor = reinterpret_cast<torch::aot_inductor::ProxyExecutor*>(proxy_executor_handle);
|
||||
|
||||
CONVERT_EXCEPTION_TO_ERROR_CODE(
|
||||
{ container->run(input_tensors, output_tensors, stream); })
|
||||
{ container->run(input_tensors, output_tensors, stream, proxy_executor); })
|
||||
}
|
||||
|
||||
AOTInductorError AOTInductorModelContainerGetNumInputs(
|
||||
|
@ -110,10 +110,16 @@ VECTORIZABLE_RTYPES = {
|
||||
}
|
||||
|
||||
PYTHON_TO_CPP = {
|
||||
"Tensor": "at::Tensor",
|
||||
"int": "long",
|
||||
"float": "double",
|
||||
"bool": "bool",
|
||||
"str": "std::string",
|
||||
"ScalarType": "c10::ScalarType",
|
||||
"MemoryFormat": "at::MemoryFormat",
|
||||
"Layout": "at::Layout",
|
||||
"Device": "at::Device",
|
||||
"number": "at::Scalar",
|
||||
}
|
||||
|
||||
CONTAINER_PYTHON_TO_CPP = {
|
||||
|
@ -472,6 +472,8 @@ class WrapperCodeGen(CodeGen):
|
||||
cpp_op_schema,
|
||||
cpp_kernel_key,
|
||||
cpp_kernel_overload_name="",
|
||||
op_overload=None,
|
||||
raw_args=None,
|
||||
):
|
||||
self.writeline(f"{name} = {kernel}({', '.join(codegen_args)})")
|
||||
|
||||
@ -916,6 +918,7 @@ class CppWrapperCodeGen(WrapperCodeGen):
|
||||
self.supports_intermediate_hooks = False
|
||||
self.outputs_need_copy = set()
|
||||
self.resized_outputs = {}
|
||||
self.kernel_callsite_id = count()
|
||||
|
||||
from .cpp import cexpr
|
||||
|
||||
@ -975,7 +978,8 @@ class CppWrapperCodeGen(WrapperCodeGen):
|
||||
void AOTInductorModel::run_impl(
|
||||
const std::vector<at::Tensor>& args,
|
||||
std::vector<at::Tensor>& outputs,
|
||||
cudaStream_t stream) {
|
||||
cudaStream_t stream,
|
||||
ProxyExecutor* proxy_executor) {
|
||||
"""
|
||||
)
|
||||
else:
|
||||
@ -1298,6 +1302,112 @@ class CppWrapperCodeGen(WrapperCodeGen):
|
||||
f"{self.codegen_tensor_option(device, dtype)};"
|
||||
)
|
||||
|
||||
def generate_extern_kernel_args_decl_if_needed(
|
||||
self, op_overload, raw_args, output_args
|
||||
):
|
||||
arg_types = [x.real_type for x in op_overload._schema.arguments]
|
||||
return_types = [x.type for x in op_overload._schema.returns]
|
||||
|
||||
new_tensor_args = []
|
||||
new_int_args = []
|
||||
|
||||
def fill_args(arg, arg_type):
|
||||
static_arg_types = (
|
||||
torch.FloatType,
|
||||
torch.BoolType,
|
||||
torch.StringType,
|
||||
torch.Type,
|
||||
torch.DeviceObjType,
|
||||
)
|
||||
|
||||
if isinstance(arg_type, torch.TensorType):
|
||||
assert isinstance(arg, (ir.InputBuffer, ir.ComputedBuffer))
|
||||
new_tensor_args.append(f"&{arg.name}")
|
||||
elif isinstance(arg_type, (torch.IntType, torch.SymIntType)):
|
||||
# int or SymInt
|
||||
assert isinstance(arg, int)
|
||||
new_int_args.append(str(arg))
|
||||
elif isinstance(arg_type, torch.NumberType):
|
||||
# Scalar of type int
|
||||
assert isinstance(arg, (int, float, bool))
|
||||
# Only treat int Scalar as dynamic
|
||||
if isinstance(arg, int):
|
||||
new_int_args.append(str(arg))
|
||||
elif isinstance(arg_type, torch.ListType):
|
||||
assert isinstance(arg, (list, tuple))
|
||||
|
||||
# List[Tensor]
|
||||
if isinstance(arg_type.getElementType(), torch.TensorType):
|
||||
new_tensor_args.extend([f"&{a.name}" for a in arg])
|
||||
# List[Optional[Tensor]]
|
||||
elif isinstance(
|
||||
arg_type.getElementType(), torch.OptionalType
|
||||
) and isinstance(
|
||||
arg_type.getElementType().getElementType(), torch.TensorType
|
||||
):
|
||||
new_tensor_args.extend([f"&{a.name}" for a in arg if a is not None])
|
||||
# List [int] or List[SymInt]
|
||||
elif isinstance(
|
||||
arg_type.getElementType(), (torch.IntType, torch.SymIntType)
|
||||
):
|
||||
new_int_args.extend([str(a) for a in arg])
|
||||
# List[Scalar]
|
||||
elif isinstance(arg_type.getElementType(), torch.NumberType):
|
||||
# Only treat int Scalar as dynamic
|
||||
is_int_type = [isinstance(a, int) for a in arg]
|
||||
if any(is_int_type):
|
||||
assert all(
|
||||
is_int_type
|
||||
), "AOTInductor only supports int scalars of the same type"
|
||||
new_int_args.extend([str(a) for a in arg])
|
||||
else:
|
||||
assert isinstance(
|
||||
arg_type.getElementType(), static_arg_types
|
||||
), f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"
|
||||
else:
|
||||
assert isinstance(
|
||||
arg_type, static_arg_types
|
||||
), f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"
|
||||
|
||||
for arg, arg_type in zip(raw_args, arg_types):
|
||||
if arg is not None:
|
||||
if isinstance(arg_type, torch.OptionalType):
|
||||
fill_args(arg, arg_type.getElementType())
|
||||
else:
|
||||
fill_args(arg, arg_type)
|
||||
|
||||
def fill_output_arg(arg, return_type):
|
||||
if isinstance(return_type, torch.TensorType):
|
||||
self.writeline(f"at::Tensor {arg}; // output buffer")
|
||||
new_tensor_args.append(f"&{output_arg}")
|
||||
elif isinstance(return_type, torch.ListType) and isinstance(
|
||||
return_type.getElementType(), torch.TensorType
|
||||
):
|
||||
# TODO: handle tensor list return type
|
||||
raise NotImplementedError("NYI support for return type: List[Tensor]")
|
||||
elif isinstance(return_type, torch.SymIntType):
|
||||
raise NotImplementedError("NYI support for return type: SymInt")
|
||||
elif isinstance(return_type, torch.ListType) and isinstance(
|
||||
return_type.getElementType(), torch.SymIntType
|
||||
):
|
||||
raise NotImplementedError("NYI support for return type: List[SymInt]")
|
||||
else:
|
||||
raise AssertionError(f"Unsupport return type found: {return_type}")
|
||||
|
||||
assert (
|
||||
len(output_args) == 1
|
||||
), "Support for multiple returns is not implemented yet"
|
||||
for output_arg, return_type in zip(output_args, return_types):
|
||||
# TODO: check schema here
|
||||
# assume it's a tensor now
|
||||
if output_arg is not None:
|
||||
if isinstance(return_type, torch.OptionalType):
|
||||
fill_output_arg(output_arg, return_type.getElementType())
|
||||
else:
|
||||
fill_output_arg(output_arg, return_type)
|
||||
|
||||
return new_tensor_args, new_int_args
|
||||
|
||||
def generate_extern_kernel_alloc_and_find_schema_if_needed(
|
||||
self,
|
||||
name,
|
||||
@ -1306,6 +1416,37 @@ class CppWrapperCodeGen(WrapperCodeGen):
|
||||
cpp_op_schema,
|
||||
cpp_kernel_key,
|
||||
cpp_kernel_overload_name="",
|
||||
op_overload=None,
|
||||
raw_args=None,
|
||||
):
|
||||
if config.is_fbcode():
|
||||
assert op_overload is not None
|
||||
assert raw_args is not None
|
||||
|
||||
return self.generate_extern_kernel_alloc_and_find_schema_if_needed_fbcode(
|
||||
name,
|
||||
cpp_kernel_key,
|
||||
op_overload,
|
||||
raw_args,
|
||||
)
|
||||
else:
|
||||
return self.generate_extern_kernel_alloc_and_find_schema_if_needed_oss(
|
||||
name,
|
||||
kernel,
|
||||
codegen_args,
|
||||
cpp_op_schema,
|
||||
cpp_kernel_key,
|
||||
cpp_kernel_overload_name,
|
||||
)
|
||||
|
||||
def generate_extern_kernel_alloc_and_find_schema_if_needed_oss(
|
||||
self,
|
||||
name,
|
||||
kernel,
|
||||
codegen_args,
|
||||
cpp_op_schema,
|
||||
cpp_kernel_key,
|
||||
cpp_kernel_overload_name="",
|
||||
):
|
||||
if cpp_kernel_key not in self.extern_call_ops:
|
||||
self.writeline(
|
||||
@ -1321,6 +1462,43 @@ class CppWrapperCodeGen(WrapperCodeGen):
|
||||
f"auto {name} = op_{cpp_kernel_key}.call({', '.join(codegen_args)});"
|
||||
)
|
||||
|
||||
def generate_extern_kernel_alloc_and_find_schema_if_needed_fbcode(
|
||||
self,
|
||||
name,
|
||||
cpp_kernel_key,
|
||||
op_overload,
|
||||
raw_args, # contains both args and flatten kwargs
|
||||
):
|
||||
output_args = [name]
|
||||
|
||||
(
|
||||
tensor_call_args,
|
||||
int_call_args,
|
||||
) = self.generate_extern_kernel_args_decl_if_needed(
|
||||
op_overload, raw_args, output_args
|
||||
)
|
||||
|
||||
tensor_args_var = f"tensor_args_var_{next(self.kernel_callsite_id)}"
|
||||
tensor_call_args_str = ", ".join(tensor_call_args)
|
||||
self.writeline(f"void* {tensor_args_var}[] = {{{tensor_call_args_str}}};")
|
||||
|
||||
int_args_var = f"int_args_var_{next(self.kernel_callsite_id)}"
|
||||
int_call_args_str = ", ".join(int_call_args)
|
||||
self.writeline(f"int64_t {int_args_var}[] = {{{int_call_args_str}}};")
|
||||
|
||||
extern_kernel_node_index = len(V.graph.extern_kernel_nodes) - 1
|
||||
|
||||
self.writeline(
|
||||
f"proxy_executor->call_function("
|
||||
f"{extern_kernel_node_index}, "
|
||||
f"{len(int_call_args)}, "
|
||||
f"{int_args_var}, "
|
||||
f"{len(tensor_call_args)}, "
|
||||
f"{tensor_args_var});"
|
||||
)
|
||||
|
||||
self.extern_call_ops.add(cpp_kernel_key)
|
||||
|
||||
def val_to_arg_str(self, val):
|
||||
from .cpp import DTYPE_TO_ATEN
|
||||
|
||||
@ -1353,7 +1531,6 @@ class CudaWrapperCodeGen(CppWrapperCodeGen):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.kernel_callsite_id = count()
|
||||
self.arg_var_id = count()
|
||||
self.cuda = True
|
||||
|
||||
|
@ -39,6 +39,7 @@ from .fx_passes.joint_graph import joint_graph_passes
|
||||
from .fx_passes.post_grad import post_grad_passes, view_to_reshape
|
||||
from .fx_passes.pre_grad import pre_grad_passes
|
||||
from .graph import GraphLowering
|
||||
from .ir import ExternKernelNode
|
||||
from .pattern_matcher import clone_graph
|
||||
from .utils import get_dtype_size, has_incompatible_cudagraph_ops
|
||||
from .virtualized import V
|
||||
@ -299,6 +300,7 @@ def compile_fx_inner(
|
||||
boxed_forward_device_index: Optional[BoxedDeviceIndex] = None,
|
||||
user_visible_outputs: FrozenSet[str] = frozenset(),
|
||||
layout_opt: Optional[bool] = None,
|
||||
extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None,
|
||||
):
|
||||
"""
|
||||
Inductor API that compiles a single graph.
|
||||
@ -340,6 +342,7 @@ def compile_fx_inner(
|
||||
"is_inference": is_inference,
|
||||
"user_visible_outputs": user_visible_outputs,
|
||||
"layout_opt": layout_opt,
|
||||
"extern_node_serializer": extern_node_serializer,
|
||||
}
|
||||
|
||||
compiled_graph: CompiledFxGraph = fx_codegen_and_compile(
|
||||
@ -493,6 +496,7 @@ def fx_codegen_and_compile(
|
||||
is_inference: bool = False,
|
||||
user_visible_outputs: FrozenSet[str] = frozenset(),
|
||||
layout_opt: Optional[bool] = None,
|
||||
extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None,
|
||||
) -> CompiledFxGraph:
|
||||
if is_tf32_warning_applicable(gm):
|
||||
_warn_tf32_disabled()
|
||||
@ -549,6 +553,7 @@ def fx_codegen_and_compile(
|
||||
cpp_wrapper=cpp_wrapper,
|
||||
aot_mode=aot_mode,
|
||||
user_visible_outputs=user_visible_outputs,
|
||||
extern_node_serializer=extern_node_serializer,
|
||||
)
|
||||
with V.set_graph_handler(graph): # type: ignore[call-arg]
|
||||
graph.run(*example_inputs)
|
||||
@ -864,11 +869,16 @@ def compile_fx_aot(
|
||||
"aot_inductor_output_path": code_hash(model_.code),
|
||||
}
|
||||
|
||||
extern_node_serializer = config_patches.pop("extern_node_serializer", None)
|
||||
with mock.patch.object(_in_aot_compilation, "value", True):
|
||||
return compile_fx(
|
||||
model_,
|
||||
example_inputs_,
|
||||
inner_compile=functools.partial(inner_compile, aot_mode=True),
|
||||
inner_compile=functools.partial(
|
||||
inner_compile,
|
||||
aot_mode=True,
|
||||
extern_node_serializer=extern_node_serializer,
|
||||
),
|
||||
config_patches=config_patches,
|
||||
)
|
||||
|
||||
|
@ -7,7 +7,7 @@ import sys
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from typing import DefaultDict, Dict, List, Optional, Set, Tuple
|
||||
from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set, Tuple
|
||||
|
||||
import sympy
|
||||
|
||||
@ -165,6 +165,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
aot_mode=False,
|
||||
user_visible_outputs=frozenset(),
|
||||
layout_opt=None,
|
||||
extern_node_serializer=None,
|
||||
):
|
||||
super().__init__(gm)
|
||||
|
||||
@ -196,6 +197,11 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
self.mutated_buffers: Set[str] = set()
|
||||
self.inplaced_to_remove: Set[str] = set()
|
||||
self.wrapper_code: Optional[WrapperCodeGen] = None
|
||||
# See `ProxyExecutor Design Note` in ir.py for more details
|
||||
self.extern_kernel_nodes: List[ir.ExternKernelNode] = []
|
||||
self.extern_node_serializer: Optional[
|
||||
Callable[[List[ir.ExternKernelNode]], Any]
|
||||
] = extern_node_serializer
|
||||
self.current_node: Optional[torch.fx.Node] = None
|
||||
self.num_static_inputs = num_static_inputs
|
||||
self.lists: Dict[str, List[str]] = {}
|
||||
@ -963,8 +969,24 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
code, linemap = self.codegen()
|
||||
output_code_log.debug("Output code: \n%s", code)
|
||||
|
||||
serialized_extern_kernel_nodes = None
|
||||
if (
|
||||
config.is_fbcode()
|
||||
and self.extern_kernel_nodes
|
||||
and self.extern_node_serializer
|
||||
):
|
||||
serialized_extern_kernel_nodes = self.extern_node_serializer(
|
||||
self.extern_kernel_nodes
|
||||
)
|
||||
output_code_log.debug(
|
||||
"Serialized Extern Kernel Nodes: \n%s",
|
||||
serialized_extern_kernel_nodes,
|
||||
)
|
||||
|
||||
# Directly return the file path with the compiled code
|
||||
return AotCodeCache.compile(self, code, cuda=self.cuda)
|
||||
return AotCodeCache.compile(
|
||||
self, code, serialized_extern_kernel_nodes, cuda=self.cuda
|
||||
)
|
||||
else:
|
||||
return self.compile_to_module().call
|
||||
|
||||
|
@ -28,11 +28,14 @@ from unittest.mock import patch
|
||||
import sympy
|
||||
from sympy import Expr, Integer
|
||||
|
||||
import torch._export.serde.schema as export_schema
|
||||
|
||||
import torch._logging
|
||||
|
||||
import torch.fx
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._dynamo.utils import identity
|
||||
from torch._export.serde.serialize import GraphModuleSerializer
|
||||
from torch._prims_common import (
|
||||
compute_required_storage_length,
|
||||
is_boolean_dtype,
|
||||
@ -3566,6 +3569,12 @@ class DynamicScalar(IRNode):
|
||||
return ()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ExternKernelNode:
|
||||
name: str
|
||||
node: export_schema.Node
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class FallbackKernel(ExternKernelAlloc):
|
||||
def __init__(
|
||||
@ -3585,6 +3594,8 @@ class FallbackKernel(ExternKernelAlloc):
|
||||
)
|
||||
self.use_cpp_op_schema = False
|
||||
|
||||
self.op_overload = kernel
|
||||
|
||||
op_overload_packet = (
|
||||
kernel._overloadpacket
|
||||
if isinstance(kernel, torch._ops.OpOverload)
|
||||
@ -3692,8 +3703,48 @@ class FallbackKernel(ExternKernelAlloc):
|
||||
return devices[0]
|
||||
return None
|
||||
|
||||
# ProxyExecutor Design Note
|
||||
# We export the ExternFallbackNodes (for custom ops) into a serialized file
|
||||
# and run it with a host side proxy executor to address the ABI problem
|
||||
# This is currently only implemented for fbcode. Eventually, we will also make this work for OSS.
|
||||
# Detailed design doc can be found at
|
||||
# https://docs.google.com/document/d/1wC4DOZFaYym2t1Esz0X5yxlLI3RDnSiyRbUus3bkJ64/edit?usp=sharing
|
||||
def export_extern_kernel_node(self):
|
||||
args, kwargs = self.unflatten_args(self.inputs, self.constant_args)
|
||||
ordered_kwargs = [
|
||||
kwargs.get(key, None) for key in self.ordered_kwargs_for_cpp_kernel
|
||||
]
|
||||
|
||||
serializer = GraphModuleSerializer(None, None, None)
|
||||
named_arguments = serializer.serialize_inputs(self.op_overload, args, kwargs)
|
||||
|
||||
# TODO: only single output is supported
|
||||
output_arguments = [
|
||||
export_schema.Argument.create(
|
||||
as_tensor=export_schema.TensorArgument(name=self.get_name())
|
||||
)
|
||||
]
|
||||
|
||||
node = ExternKernelNode(
|
||||
name=self.get_name(),
|
||||
node=export_schema.Node(
|
||||
target=self.kernel,
|
||||
inputs=named_arguments,
|
||||
outputs=output_arguments,
|
||||
metadata={},
|
||||
),
|
||||
)
|
||||
|
||||
V.graph.extern_kernel_nodes.append(node)
|
||||
|
||||
return [*args, *ordered_kwargs]
|
||||
|
||||
def codegen(self, wrapper):
|
||||
if self.use_cpp_op_schema:
|
||||
exported_args = None
|
||||
if config.is_fbcode():
|
||||
exported_args = self.export_extern_kernel_node()
|
||||
|
||||
args = [*self.codegen_args(), *self.codegen_kwargs()]
|
||||
wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
|
||||
self.get_name(),
|
||||
@ -3702,6 +3753,8 @@ class FallbackKernel(ExternKernelAlloc):
|
||||
self.cpp_op_schema,
|
||||
self.cpp_kernel_key,
|
||||
self.cpp_kernel_overlad_name,
|
||||
self.op_overload,
|
||||
exported_args,
|
||||
)
|
||||
else:
|
||||
super().codegen(wrapper)
|
||||
|
@ -44,6 +44,9 @@ using AOTInductorStreamHandle = AOTInductorStreamOpaque*;
|
||||
struct AOTInductorTensorOpaque {};
|
||||
using AOTInductorTensorHandle = AOTInductorTensorOpaque*;
|
||||
|
||||
struct AOTInductorProxyExecutorOpaque {};
|
||||
using AOTInductorProxyExecutorHandle = AOTInductorProxyExecutorOpaque*;
|
||||
|
||||
extern "C" {
|
||||
// Creates an AOTInductor model container. The parameter num_models
|
||||
// specifies the number of model instances that may be run concurrently for
|
||||
@ -63,7 +66,8 @@ AOTInductorError AOTInductorModelContainerRun(
|
||||
size_t num_inputs,
|
||||
AOTInductorTensorHandle outputs_handle,
|
||||
size_t num_outputs,
|
||||
AOTInductorStreamHandle stream_handle);
|
||||
AOTInductorStreamHandle stream_handle,
|
||||
AOTInductorProxyExecutorHandle proxy_executor_handle);
|
||||
|
||||
// Retrieves the number of inputs for the model.
|
||||
AOTInductorError AOTInductorModelContainerGetNumInputs(
|
||||
|
@ -7,6 +7,7 @@
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/csrc/inductor/proxy_executor.h>
|
||||
|
||||
#define AOT_VECTOR_SIZE_CHECK(vec, expected_size) \
|
||||
{ \
|
||||
@ -51,12 +52,13 @@ class AOTInductorModelBase {
|
||||
void run(
|
||||
const std::vector<at::Tensor>& inputs,
|
||||
std::vector<at::Tensor>& outputs,
|
||||
cudaStream_t stream) {
|
||||
cudaStream_t stream,
|
||||
ProxyExecutor* proxy_executor = nullptr) {
|
||||
AOT_VECTOR_SIZE_CHECK(inputs, num_inputs());
|
||||
AOT_VECTOR_SIZE_CHECK(outputs, num_outputs());
|
||||
|
||||
auto* model = static_cast<Model*>(this);
|
||||
model->run_impl(inputs, outputs, stream);
|
||||
model->run_impl(inputs, outputs, stream, proxy_executor);
|
||||
C10_CUDA_CHECK(cudaEventRecord(run_finished_, stream));
|
||||
}
|
||||
|
||||
@ -178,7 +180,8 @@ class AOTInductorModel : public AOTInductorModelBase<AOTInductorModel> {
|
||||
void run_impl(
|
||||
const std::vector<at::Tensor>& inputs,
|
||||
std::vector<at::Tensor>& outputs,
|
||||
cudaStream_t stream);
|
||||
cudaStream_t stream,
|
||||
ProxyExecutor* proxy_executor = nullptr);
|
||||
|
||||
static std::unique_ptr<AOTInductorModel> Create() {
|
||||
return std::make_unique<AOTInductorModel>();
|
||||
|
@ -5,6 +5,7 @@
|
||||
#include <shared_mutex>
|
||||
|
||||
#include <torch/csrc/inductor/aot_inductor_model.h>
|
||||
#include <torch/csrc/inductor/proxy_executor.h>
|
||||
|
||||
namespace torch {
|
||||
namespace aot_inductor {
|
||||
@ -56,12 +57,13 @@ class AOTInductorModelContainer {
|
||||
void run(
|
||||
const std::vector<at::Tensor>& inputs,
|
||||
std::vector<at::Tensor>& outputs,
|
||||
cudaStream_t stream) {
|
||||
cudaStream_t stream,
|
||||
ProxyExecutor* proxy_executor) {
|
||||
auto* model = get_available_model();
|
||||
try {
|
||||
AOT_VECTOR_SIZE_CHECK(inputs, num_inputs());
|
||||
AOT_VECTOR_SIZE_CHECK(outputs, num_outputs());
|
||||
model->run(inputs, outputs, stream);
|
||||
model->run(inputs, outputs, stream, proxy_executor);
|
||||
} catch (...) {
|
||||
std::lock_guard lk(models_mutex_);
|
||||
available_models_.push_back(model);
|
||||
|
24
torch/csrc/inductor/proxy_executor.h
Normal file
24
torch/csrc/inductor/proxy_executor.h
Normal file
@ -0,0 +1,24 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <c10/macros/Export.h>
|
||||
#include <string>
|
||||
|
||||
namespace torch {
|
||||
namespace aot_inductor {
|
||||
|
||||
class TORCH_API ProxyExecutor : public torch::CustomClassHolder {
|
||||
public:
|
||||
ProxyExecutor() {}
|
||||
virtual ~ProxyExecutor() {}
|
||||
|
||||
virtual void call_function(
|
||||
int extern_node_index,
|
||||
int num_ints,
|
||||
int64_t* flatten_int_args,
|
||||
int num_tensors,
|
||||
void** flatten_tensor_args) = 0;
|
||||
};
|
||||
|
||||
} // namespace aot_inductor
|
||||
} // namespace torch
|
Reference in New Issue
Block a user