Add registration API for torch.compile-eager (#121387)

This PR is a follow-up of RFC https://github.com/pytorch/pytorch/issues/115545.

In this PR, we provide a registration API dedicated to eager-through-torch.compile. The major workflow of this API is as follows.

- Load cache
- Check cache according to the input tensors
  - Cache Hit: Run the cached kernel directly
  - Cache Miss: Run AOT Inductor (AOTI) to produce the kernel and then run the produced kernel. If AOTI fails to produce the kernel, invoke the Python fallback function.

Currently, this PR always falls back to the Python kernel; the cache mechanism will be implemented in a follow-up PR - https://github.com/pytorch/pytorch/pull/116368
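For reference, here is a minimal usage sketch based on the new test in this PR (the op, shapes, and tolerance check are arbitrary; the pieces exercised are _scoped_library and the new _impl_with_aoti_compile registration):

import functools

import torch
from torch.library import _scoped_library

def fn(x, op_name=""):
    return getattr(torch, op_name)(x)

x = torch.randn(3, 4)
# Reference result through the regular torch.compile path.
ref = torch.compile(functools.partial(fn, op_name="abs"))(x)

# Route aten::abs on CPU through the AOTI-eager kernel holder, then call the
# op in plain eager mode; the registration is dropped when the scope exits.
with _scoped_library("aten", "IMPL") as lib:
    lib._impl_with_aoti_compile("abs", "CPU")
    res = torch.abs(x)

torch.testing.assert_close(ref, res)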

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121387
Approved by: https://github.com/desertfire, https://github.com/jansel, https://github.com/zou3519, https://github.com/jgong5
Wang, Eikan
2024-04-27 05:10:45 +00:00
committed by PyTorch MergeBot
parent 620d808da0
commit 61e937f3d6
6 changed files with 446 additions and 0 deletions

View File

@ -824,6 +824,7 @@ libtorch_python_core_sources = [
"torch/csrc/mps/Module.cpp",
"torch/csrc/mtia/Module.cpp",
"torch/csrc/inductor/aoti_runner/pybind.cpp",
"torch/csrc/inductor/aoti_eager/kernel_holder.cpp",
"torch/csrc/jit/backends/backend_init.cpp",
"torch/csrc/jit/python/init.cpp",
"torch/csrc/jit/passes/onnx.cpp",

View File

@ -46,6 +46,7 @@ from torch._inductor.utils import (
from torch._inductor.virtualized import V
from torch._prims_common import is_integer_dtype
from torch.fx.experimental.proxy_tensor import make_fx
from torch.library import _scoped_library
from torch.nn import functional as F
from torch.testing import FileCheck, make_tensor
from torch.testing._internal.common_cuda import (
@ -759,6 +760,70 @@ class CommonTemplate:
            ),
        )

    @skipCUDAIf(not SM80OrLater, "Requires sm80")
    def test_torch_compile_override_registration(self):
        dynamic = False
        namespace_name = "aten"
        dispatch_key = "CPU"
        device = torch.device("cpu")
        if self.device.lower() == "cuda":
            dispatch_key = "CUDA"
            device = torch.device("cuda")

        unary_op_set = ["abs", "acos"]

        def fn(x, op_name=""):
            return getattr(torch, op_name)(x)

        # Invoke torch.compile directly to get reference results
        x = torch.randn(3, 4, device=device)
        ref_array = []
        for unary_op_name in unary_op_set:
            opt_fn = torch.compile(functools.partial(fn, op_name=unary_op_name))
            ref = opt_fn(x)
            ref_array.append(ref)

        def register_ops(op_set, dispatch_key, torch_compile_op_lib_impl):
            for _op_name in op_set:
                qualified_op_name = f"{namespace_name}::{_op_name}"
                _, overload_names = torch._C._jit_get_operation(qualified_op_name)
                for overload_name in overload_names:
                    try:
                        reg_op_name = qualified_op_name
                        schema = torch._C._get_schema(qualified_op_name, overload_name)
                        if schema.overload_name:
                            reg_op_name = f"{qualified_op_name}.{schema.overload_name}"
                        torch_compile_op_lib_impl._impl_with_aoti_compile(  # noqa: F821
                            reg_op_name, dispatch_key
                        )
                    except Exception as e:
                        continue

        with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl:
            register_ops(unary_op_set, dispatch_key, torch_compile_op_lib_impl)

            res_array = []
            for unary_op_name in unary_op_set:
                res_array.append(getattr(torch, unary_op_name)(x))

            for ref, res in zip(ref_array, res_array):
                self.assertEqual(ref, res)

        a = torch.randn(128, device=device)
        min_tensor = torch.randn(128, device=device)
        max_tensor = min_tensor + 0.5
        ref_with_min = torch.ops.aten.clamp(a, min_tensor)
        ref_with_min_max = torch.ops.aten.clamp(a, min_tensor, max_tensor)

        with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl:
            register_ops(["clamp"], dispatch_key, torch_compile_op_lib_impl)
            res_with_min = torch.ops.aten.clamp(a, min_tensor)
            res_with_min_max = torch.ops.aten.clamp(a, min_tensor, max_tensor)
            self.assertEqual(ref_with_min, res_with_min)
            self.assertEqual(ref_with_min_max, res_with_min_max)

    def test_add_const_int(self):
        def fn(a):
            return (a + 1, torch.add(a, 1, alpha=2))

View File

@ -0,0 +1,246 @@
#if !defined(C10_MOBILE) && !defined(ANDROID)
#include <torch/csrc/inductor/aoti_eager/kernel_holder.h>

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/csrc/PyInterpreter.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
#ifdef USE_CUDA
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
#endif
#include <torch/csrc/jit/frontend/function_schema_parser.h>

namespace torch::inductor {

namespace {

inline void unpack_tensor_ivalue(
    const c10::IValue& ivalue,
    const c10::Device& device,
    std::vector<at::Tensor>& inputs) {
  inputs.push_back(ivalue.toTensor());
}

inline void unpack_optional_tensor_ivalue(
    const c10::IValue& ivalue,
    const c10::Device& device,
    std::vector<at::Tensor>& inputs) {
  auto ivalue_opt_tensor = ivalue.toOptional<at::Tensor>();
  if (ivalue_opt_tensor.has_value()) {
    inputs.push_back(ivalue_opt_tensor.value());
  }
}

inline void unpack_tensor_list_ivalue(
    const c10::IValue& ivalue,
    const c10::Device& device,
    std::vector<at::Tensor>& inputs) {
  for (const auto& item : ivalue.toListRef()) {
    inputs.push_back(item.toTensor());
  }
}

inline void unpack_optional_tensor_list_ivalue(
    const c10::IValue& ivalue,
    const c10::Device& device,
    std::vector<at::Tensor>& inputs) {
  for (const auto& item : ivalue.toListRef()) {
    unpack_optional_tensor_ivalue(item, device, inputs);
  }
}

inline void unpack_scalar_ivalue(
    const c10::IValue& ivalue,
    const c10::Device& device,
    std::vector<at::Tensor>& inputs) {
  inputs.push_back(at::scalar_tensor(
      ivalue.toScalar(),
      c10::TensorOptions().device(device).dtype(ivalue.toScalar().type())));
}

bool unpack_ivalue(
    const c10::Argument& argument,
    const c10::IValue& ivalue,
    const c10::Device& device,
    std::vector<at::Tensor>& inputs) {
  if (ivalue.isTensor()) {
    unpack_tensor_ivalue(ivalue, device, inputs);
  } else if (ivalue.isTensorList()) {
    unpack_tensor_list_ivalue(ivalue, device, inputs);
  } else if (ivalue.isOptionalTensorList()) {
    unpack_optional_tensor_list_ivalue(ivalue, device, inputs);
  } else if (ivalue.isScalar()) {
    // ivalue is a scalar
    unpack_scalar_ivalue(ivalue, device, inputs);
  } else if (
      *argument.real_type() == *c10::getTypePtr<c10::optional<at::Tensor>>()) {
    // ivalue is c10::optional<at::Tensor>
    unpack_optional_tensor_ivalue(ivalue, device, inputs);
  } else {
    // Unsupported IValue type.
    return false;
  }

  return true;
}

bool unpack_tensors(
    const std::vector<c10::Argument>& arguments,
    const torch::jit::Stack& stack,
    const c10::Device& device,
    std::vector<at::Tensor>& inputs) {
  for (size_t idx = 0; idx < stack.size(); idx++) {
    if (!unpack_ivalue(arguments[idx], stack[idx], device, inputs)) {
      return false;
    }
  }

  return true;
}

} // namespace
AOTIPythonKernelHolder::AOTIPythonKernelHolder(
    c10::DispatchKey dispatch_key,
    c10::string_view ns,
    c10::string_view op_name_with_overload)
    : dispatch_key_(dispatch_key),
      ns_(std::string(ns)),
      op_name_with_overload_(std::string(op_name_with_overload)),
      device_(c10::dispatchKeyToDeviceType(dispatch_key_), 0),
      pyinterpreter_(getPyInterpreter()) {
  TORCH_CHECK(
      (device_.type() == c10::DeviceType::CPU) ||
          (device_.type() == c10::DeviceType::CUDA),
      "Unsupported device type");
}

void AOTIPythonKernelHolder::operator()(
    const c10::OperatorHandle& op,
    c10::DispatchKeySet keyset,
    torch::jit::Stack* stack) {
  if (cache_lookup(op, keyset, stack)) {
    cache_hit(op, keyset, stack);
  } else {
    cache_miss(op, keyset, stack);
  }
}

bool AOTIPythonKernelHolder::cache_lookup(
    const c10::OperatorHandle& op,
    c10::DispatchKeySet keyset,
    torch::jit::Stack* stack) {
  // TODO: Always return false for now to exercise cache_miss. Later, we will
  // add cache lookup and implement cache_hit.
  return false;
}

void AOTIPythonKernelHolder::cache_hit(
    const c10::OperatorHandle& op,
    c10::DispatchKeySet keyset,
    torch::jit::Stack* stack) {
  TORCH_INTERNAL_ASSERT(false);
}

void AOTIPythonKernelHolder::cache_miss(
    const c10::OperatorHandle& op,
    c10::DispatchKeySet keyset,
    torch::jit::Stack* stack) {
  auto kernel_lib_path = produce_aoti_kernel_lib(op, keyset, stack);
  std::shared_ptr<AOTIModelContainerRunner> kernel = nullptr;
  // TODO: Enable a plugin mechanism to allow registration for other backends.
  if (device_.type() == c10::DeviceType::CPU) {
    kernel = std::make_shared<AOTIModelContainerRunnerCpu>(kernel_lib_path);
  } else {
#ifdef USE_CUDA
    kernel = std::make_shared<AOTIModelContainerRunnerCuda>(kernel_lib_path);
#else
    TORCH_CHECK(false, "Unsupported CUDA device type");
#endif
  }

  std::vector<at::Tensor> inputs;
  TORCH_INTERNAL_ASSERT(
      unpack_tensors(op.schema().arguments(), *stack, device_, inputs),
      "Failed to unpack tensors from the stack to run the AOTI kernel.");
  auto outputs = kernel->run(inputs);
  if (outputs.size() > 0) {
    torch::jit::drop(*stack, op.schema().arguments().size());
    // TODO: Get the output type of this operation and then convert to the
    // output type.
    for (auto& output : outputs) {
      torch::jit::push(*stack, std::move(output));
    }
  }
}
std::string AOTIPythonKernelHolder::produce_aoti_kernel_lib(
    const c10::OperatorHandle& op,
    c10::DispatchKeySet keyset,
    torch::jit::Stack* stack) {
  auto arguments = torch::jit::last(*stack, op.schema().arguments().size());

  py::gil_scoped_acquire gil;

  // Get the corresponding Python operation for the current operator; the
  // Python operation will be passed to AOT Inductor to generate the kernel
  // library.
  const auto& schema = op.schema();
  const auto& qualified_name = op.operator_name().name;
  const auto& overload_name =
      schema.overload_name().empty() ? "default" : schema.overload_name();
  auto pos = qualified_name.find("::");
  TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name);
  // Make me some null terminated strings
  std::string ns_str = qualified_name.substr(0, pos);
  const char* ns = ns_str.c_str();
  const char* func_name = qualified_name.c_str() + pos + strlen("::");
  py::handle op_py_func = op.getPythonOp(pyinterpreter_, [&]() -> PyObject* {
    py::handle torch_api_function =
        py::module::import("torch").attr("ops").attr(ns).attr(func_name);
    return torch_api_function.attr(overload_name.c_str()).ptr();
  });

  TORCH_INTERNAL_ASSERT(
      op_py_func.ptr() != nullptr && op_py_func.ptr() != Py_None,
      "Failed to get python operation. Operator Name is ",
      op.operator_name().name,
      ", Overload Name is ",
      overload_name);

  py::handle aot_compile_function =
      py::module::import("torch._export").attr("aot_compile");
  TORCH_INTERNAL_ASSERT(
      aot_compile_function.ptr() != nullptr &&
          aot_compile_function.ptr() != Py_None,
      "Failed to import - torch._export.aot_compile");

  // Pass the Python operation to AOT Inductor to generate the kernel library.
  auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments.vec());
  auto result = py::reinterpret_steal<py::object>(PyObject_CallFunctionObjArgs(
      aot_compile_function.ptr(),
      op_py_func.ptr(),
      args_kwargs.first.ptr(),
      args_kwargs.second.ptr(),
      nullptr));
  TORCH_INTERNAL_ASSERT(result.ptr() != nullptr && result.ptr() != Py_None);

  auto kernel_lib_path = py::cast<std::string>(result);
  TORCH_CHECK(
      !kernel_lib_path.empty(),
      "Failed to produce kernel library by using AOTI for ",
      c10::DeviceTypeName(device_.type()),
      ". Operator Name is ",
      op.operator_name().name,
      ", Overload Name is ",
      op.schema().overload_name());

  return kernel_lib_path;
}

} // namespace torch::inductor
#endif
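
At the Python level, the cache-miss path above is roughly equivalent to the following sketch (not part of this diff; the concrete op and inputs are illustrative, and the real code resolves the OpOverload via OperatorHandle::getPythonOp and converts the boxed stack with parseIValuesToPyArgsKwargs):

import torch

# produce_aoti_kernel_lib resolves torch.ops.<ns>.<op>.<overload> for the
# operator being dispatched; aten::abs with the default overload is just an
# example here.
op = torch.ops.aten.abs.default
args, kwargs = (torch.randn(3, 4),), {}

# AOT Inductor compiles the op for these inputs and returns the path of the
# generated kernel library, which cache_miss then loads and runs.
kernel_lib_path = torch._export.aot_compile(op, args, kwargs)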

View File

@ -0,0 +1,65 @@
#if !defined(C10_MOBILE) && !defined(ANDROID)
#pragma once

#include <ATen/ATen.h>
#include <ATen/core/boxing/KernelFunction.h>
#include <torch/csrc/inductor/aoti_runner/model_container_runner.h>
#include <torch/csrc/utils/pybind.h>

#include <string>

namespace torch::inductor {

// The AOTIPythonKernelHolder class uses the AOT Inductor to generate a kernel
// for a specified operation. To speed up this process, the generated kernel
// library is cached on disk. Detailed information from the input tensors is
// used as the key for caching the kernel library. On subsequent runs, these
// input tensors are used to search the cache. If a cache hit occurs, the cached
// kernel library is loaded and executed. If a cache miss occurs, the AOT
// Inductor is called again to generate the kernel library.
class AOTIPythonKernelHolder : public c10::OperatorKernel {
  // A DispatchKey object that represents the dispatch key for the kernel.
  c10::DispatchKey dispatch_key_;
  // Namespace of the kernel.
  std::string ns_;
  // Name of the operation the kernel performs.
  std::string op_name_with_overload_;
  // The device on which the kernel is to be executed.
  c10::Device device_;
  // The Python interpreter to get the OpOverload object with the given op_name
  // and op_overload_name.
  c10::impl::PyInterpreter* pyinterpreter_;

 public:
  AOTIPythonKernelHolder(
      c10::DispatchKey dispatch_key,
      c10::string_view ns,
      c10::string_view op_name_with_overload);

  void operator()(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet keyset,
      torch::jit::Stack* stack);

 private:
  bool cache_lookup(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet keyset,
      torch::jit::Stack* stack);
  void cache_miss(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet keyset,
      torch::jit::Stack* stack);
  void cache_hit(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet keyset,
      torch::jit::Stack* stack);
  std::string produce_aoti_kernel_lib(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet keyset,
      torch::jit::Stack* stack);
};

} // namespace torch::inductor
#endif
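
The on-disk cache itself is not implemented in this PR (it lands in https://github.com/pytorch/pytorch/pull/116368). Purely as an illustration of what "detailed information from the input tensors" could look like as a lookup key, here is a hypothetical sketch assuming per-tensor dtype, device, sizes, and strides are recorded; the actual key format may differ:

import torch

def illustrative_cache_key(inputs):
    # Hypothetical helper, not part of this PR: summarize per-tensor metadata
    # into a hashable key for looking up a previously generated kernel library.
    return tuple(
        (t.dtype, t.device.type, tuple(t.size()), tuple(t.stride()))
        for t in inputs
    )

key = illustrative_cache_key([torch.randn(3, 4), torch.randn(4)])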

View File

@ -21,6 +21,7 @@
#include <c10/util/flat_hash_map.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <torch/csrc/inductor/aoti_eager/kernel_holder.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_raii.h>
@ -372,6 +373,32 @@ void initDispatchBindings(PyObject* module) {
py::arg("name"),
py::arg("dispatch") = "",
py::arg("debug") = "impl_t_t")
.def(
"impl_with_aoti_compile",
[](const py::object& self,
const char* ns,
const char* op_name_with_overload,
c10::DispatchKey dispatch) {
HANDLE_TH_ERRORS
std::string reg_op_name =
std::string(ns).append("::").append(op_name_with_overload);
auto& lib = self.cast<torch::Library&>();
lib.impl(
reg_op_name.c_str(),
torch::dispatch(
dispatch,
CppFunction::makeFromBoxedFunctor(
std::make_unique<
torch::inductor::AOTIPythonKernelHolder>(
dispatch, ns, op_name_with_overload))),
register_or_verify());
END_HANDLE_TH_ERRORS_PYBIND
},
"",
py::arg("ns"),
py::arg("op_name_with_overload"),
py::arg("dispatch"))
.def(
"impl",
[](const py::object& self,

View File

@ -139,6 +139,48 @@ class Library:
        handle = entry.abstract_impl.register(func_to_register, source)
        self._registration_handles.append(handle)

    def _impl_with_aoti_compile(self, op_name, dispatch_key=''):
        r'''Register the operator to use the AOTI-compiled implementation.

        Args:
            op_name: operator name (along with the overload) or OpOverload object.
            dispatch_key: dispatch key that the input function should be registered for. By default, it uses
                          the dispatch key that the library was created with.

        Example::
            >>> my_lib = Library("aten", "IMPL")
            >>> my_lib._impl_with_aoti_compile("div.Tensor", "CPU")
        '''
        if dispatch_key == '':
            dispatch_key = self.dispatch_key
        assert torch.DispatchKeySet(dispatch_key).has(torch._C.DispatchKey.Dense)

        if isinstance(op_name, str):
            name = op_name
        elif isinstance(op_name, OpOverload):
            name = op_name._schema.name
            overload_name = op_name._schema.overload_name
            if overload_name != '':
                name = name + '.' + overload_name
        else:
            raise RuntimeError("_impl_with_aoti_compile should be passed either a name or an OpOverload object "
                               "as the first argument")

        key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key
        if key in _impls:
            # TODO: in future, add more info about where the existing function is registered (this info is
            # today already returned by the C++ warning when _impl_with_aoti_compile is called but we error out before that)
            raise RuntimeError("This is not allowed since there's already a kernel registered from python overriding {}"
                               "'s behavior for {} dispatch key and {} namespace.".
                               format(name.split("::")[-1], dispatch_key, self.ns))

        assert self.m is not None
        impl_fn: Callable = self.m.impl_with_aoti_compile
        impl_fn(self.ns, name.split("::")[-1], dispatch_key)

        _impls.add(key)
        self._op_impls.add(key)

    def impl(self, op_name, fn, dispatch_key='', *, with_keyset=False):
        r'''Registers the function implementation for an operator defined in the library.
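
Usage note (illustration, not part of this diff): besides the string form shown in the docstring, _impl_with_aoti_compile also accepts an OpOverload, in which case the overload name is appended automatically. A small sketch mirroring the docstring example:

import torch
from torch.library import Library

my_lib = Library("aten", "IMPL")
# Equivalent to my_lib._impl_with_aoti_compile("div.Tensor", "CPU"): the
# OpOverload's schema name and overload name are combined into "div.Tensor".
my_lib._impl_with_aoti_compile(torch.ops.aten.div.Tensor, "CPU")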