Add a cache mechanism to accelerate torch.compile-for-eager (#116368)
This PR is a follow-up of RFC https://github.com/pytorch/pytorch/issues/115545. It enables a cache mechanism to accelerate **eager-through-torch.compile**. When **eager-through-torch.compile** is enabled, we store a persistent config that caches the kernel information for an aten operation. The persistent config consists of two parts:

- meta_info: the input tensors' shape, stride, device type, data type, and symbolic flag.
- kernel_path: the path of the kernel produced by Inductor.

When an aten operation is registered, the `kernel_holder` loads the persistent config and parses it to build the cache map; the meta_info is the key, and the kernel library is the value. Currently, this PR only supports static shapes to guard the kernel.

Take `mul` as an example.

```python
import time
from typing import Any

import torch


class MulKernel:
    def __init__(self) -> None:
        pass

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        with torch._C._SetExcludeDispatchKeyGuard(torch._C.DispatchKey.Python, False):
            opt_fn = torch.compile(
                torch.ops.aten.mul,
                dynamic=False,
                options={
                    "aot_inductor.eager_mode": True,
                    "aot_inductor.eager_op_name": "mul_Tensor",
                },
            )
            return opt_fn(*args, **kwargs)


device = "cuda"
torch_compile_op_lib_impl = torch.library.Library("aten", "IMPL")
_, overload_names = torch._C._jit_get_operation("aten::mul")
overload_name = "Tensor"  # register the mul.Tensor overload
schema = torch._C._get_schema("aten::mul", overload_name)
reg_name = schema.name
if schema.overload_name:
    reg_name = f"{reg_name}.{schema.overload_name}"
torch_compile_op_lib_impl.impl(reg_name, MulKernel(), "CUDA", compile_mode=True)

a = torch.randn(1024, 1024, device=device)
b = torch.randn(1024, 1024, device=device)

warm_up_iter = 1000
iter = 10000
fn = torch.mul

# Warm up
for _ in range(warm_up_iter):
    fn(a, b)

# Collect performance
beg = time.time()
for _ in range(iter):
    fn(a, b)
end = time.time()
print(f"E2E run: {end - beg}")
```

It will produce the config as follows.

```json
[
    {
        "meta_info": [
            {
                "is_symbolic": false,
                "device_type": "cuda",
                "dtype": "torch.float32",
                "sizes": [1024, 1024],
                "strides": [1024, 1]
            },
            {
                "is_symbolic": false,
                "device_type": "cuda",
                "dtype": "torch.float32",
                "sizes": [1024, 1024],
                "strides": [1024, 1]
            }
        ],
        "kernel_path": "/tmp/torchinductor_eikan/e4/ce4jw46i5l2e7v3tvr2pyglpjmahnp7x3hxaqotrvxwoeh5t6qzc.so"
    }
]
```

Performance-wise, we collected mul.Tensor through torch.compile with 10000 end-to-end runs. The data is as follows; we will collect data again once dynamic shapes are supported.

- Eager: ~266.11ms
- W/O Cache: ~3455.54ms
- W/ Cache and Cache Miss: ~3555.3ms
- W/ Cache and Cache Hit: ~267.12ms

Hardware:

- CPU: Intel(R) Xeon(R) Platinum 8260 CPU @ 2.40GHz
- GPU: CUDA A10

Software:

- PyTorch Version: 39df084001c54cca5fe3174176f9b0206ddb7dcf
- GPU Driver Version: 525.147.05
- CUDA Version: 12.0

Differential Revision: [D57216427](https://our.internmc.facebook.com/intern/diff/D57216427)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116368
Approved by: https://github.com/jansel, https://github.com/atalman
Committed by: PyTorch MergeBot
Parent: b3a8a3cbab
Commit: d1f254dce8
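To make the caching scheme described above concrete, the sketch below shows how the static-shape meta_info can serve as a dictionary key that maps to the cached kernel path. This is an illustration only, not code from this PR; `make_meta_key`, `load_cache`, and `lookup_kernel` are hypothetical helpers.

```python
import json
from pathlib import Path

import torch


def make_meta_key(tensors):
    # Mirror the meta_info fields: symbolic flag, device type, dtype, sizes, strides.
    return tuple(
        (False, t.device.type, str(t.dtype), tuple(t.size()), tuple(t.stride()))
        for t in tensors
    )


def load_cache(config_path: Path):
    # Parse the persistent config into {meta_info key -> kernel library path}.
    cache = {}
    for entry in json.loads(config_path.read_text()):
        key = tuple(
            (
                m["is_symbolic"],
                m["device_type"],
                m["dtype"],
                tuple(m["sizes"]),
                tuple(m["strides"]),
            )
            for m in entry["meta_info"]
        )
        cache[key] = entry["kernel_path"]
    return cache


def lookup_kernel(cache, *tensors):
    # A hit returns the path of a previously compiled kernel library;
    # a miss means AOT Inductor has to compile a new one.
    return cache.get(make_meta_key(tensors))


if __name__ == "__main__":
    t = torch.randn(1024, 1024)
    print(make_meta_key([t]))
```

A cache hit therefore only needs to rebuild the key from the incoming tensors, which is consistent with the cache-hit timing above sitting close to plain eager.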
@@ -826,6 +826,7 @@ libtorch_python_core_sources = [
    "torch/csrc/mtia/Module.cpp",
    "torch/csrc/inductor/aoti_runner/pybind.cpp",
    "torch/csrc/inductor/aoti_eager/kernel_holder.cpp",
+   "torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp",
    "torch/csrc/jit/backends/backend_init.cpp",
    "torch/csrc/jit/python/init.cpp",
    "torch/csrc/jit/passes/onnx.cpp",
@@ -17,7 +17,9 @@ import threading
import time
import typing
import unittest
import unittest.mock
import weakref
from pathlib import Path
from typing import Tuple
from unittest.mock import patch

@@ -40,6 +42,9 @@ from torch._inductor.fx_passes import pad_mm
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import (
    add_scheduler_init_hook,
    aoti_compile_with_persistent_cache,
    aoti_eager_cache_dir,
    load_aoti_eager_cache,
    run_and_get_code,
    run_and_get_triton_code,
)
@@ -761,6 +766,102 @@ class CommonTemplate:
            ),
        )

    @skipCUDAIf(not SM80OrLater, "Requires sm80")
    def test_eager_aoti_cache_hit(self):
        ns = "aten"
        op_name = "abs"
        dispatch_key = "CPU"
        device = "cpu"
        if self.device.lower() == "cuda":
            dispatch_key = "CUDA"
            device = "cuda"

        input_tensor = torch.randn(128, dtype=torch.float, device=device)
        kernel_lib_path = aoti_compile_with_persistent_cache(
            ns,
            op_name,
            device,
            False,
            getattr(torch.ops.aten, op_name),
            (input_tensor,),
            {},
        )
        self.assertTrue(Path(kernel_lib_path).exists())

        from unittest import mock

        # Patch the aoti_compile_with_persistent_cache as None to ensure no new kernel is generated
        with mock.patch(
            "torch._inductor.utils.aoti_compile_with_persistent_cache", None
        ):
            qualified_op_name = f"{ns}::{op_name}"
            _, overload_names = torch._C._jit_get_operation(qualified_op_name)

            with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl:
                # Get ref result from eager
                ref_value = getattr(torch.ops.aten, op_name)(input_tensor)

                for overload_name in overload_names:
                    try:
                        reg_op_name = qualified_op_name
                        schema = torch._C._get_schema(qualified_op_name, overload_name)
                        if schema.overload_name:
                            reg_op_name = f"{qualified_op_name}.{schema.overload_name}"
                        torch_compile_op_lib_impl._impl_with_aoti_compile(  # noqa: F821
                            reg_op_name, dispatch_key
                        )
                    except Exception as e:
                        continue

                # Invoke the pre-compiled kernel and get result.
                res_value = getattr(torch.ops.aten, op_name)(input_tensor)

                self.assertEqual(ref_value, res_value)

    @skipCUDAIf(not SM80OrLater, "Requires sm80")
    def test_aoti_compile_with_persistent_cache(self):
        def fn(a):
            return torch.abs(a)

        ns = "aten"
        op_name = "abs"

        device = "cpu"
        if self.device.lower() == "cuda":
            device = "cuda"

        input_tensor = torch.randn(128, dtype=torch.float, device=device)
        kernel_lib_path = aoti_compile_with_persistent_cache(
            ns,
            op_name,
            input_tensor.device.type,
            False,
            fn,
            args=(input_tensor,),
            kwargs={},
        )
        self.assertTrue(len(kernel_lib_path) > 0)

        device_kernel_cache = aoti_eager_cache_dir(ns, device)
        kernel_conf = device_kernel_cache / f"{op_name}.json"
        self.assertTrue(kernel_conf.exists())

        json_data = load_aoti_eager_cache("aten", "abs", input_tensor.device.type)
        self.assertTrue(json_data is not None)
        self.assertTrue(isinstance(json_data, list))
        self.assertTrue(len(json_data) > 0)

        op_info = json_data[0]
        self.assertTrue(isinstance(op_info, dict))
        self.assertTrue("meta_info" in op_info)
        self.assertTrue("kernel_path" in op_info)
        kernel_libs_abs_path = []
        for item in json_data:
            kernel_path = device_kernel_cache / item["kernel_path"]
            kernel_libs_abs_path.append(kernel_path.as_posix())

        self.assertTrue(kernel_lib_path in kernel_libs_abs_path)

    @skipCUDAIf(not SM80OrLater, "Requires sm80")
    def test_torch_compile_override_registration(self):
        dynamic = False
@@ -8,6 +8,7 @@ import functools
import inspect
import io
import itertools
import json
import logging
import math
import operator
@@ -21,6 +22,7 @@ import time
import unittest
from datetime import datetime
from io import StringIO
from pathlib import Path
from typing import (
    Any,
    Callable,
@@ -32,6 +34,7 @@ from typing import (
    Optional,
    Protocol,
    Set,
    Tuple,
    TypeVar,
    Union,
    ValuesView,
@@ -42,6 +45,8 @@ import sympy
from typing_extensions import Concatenate, ParamSpec

import torch
import torch._export
import torch.utils._pytree as pytree
from torch._dynamo.device_interface import get_interface_for_device
from torch._dynamo.utils import detect_fake_mode
from torch.autograd import DeviceType
@@ -51,7 +56,7 @@ from torch.utils._sympy.functions import CeilDiv, CleanDiv, FloorDiv, ModularInd
from torch.utils._sympy.symbol import make_symbol, SymT
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
from . import config
-from .runtime.runtime_utils import ceildiv as runtime_ceildiv
+from .runtime.runtime_utils import cache_dir, ceildiv as runtime_ceildiv

log = logging.getLogger(__name__)

@@ -1544,3 +1549,140 @@ def maybe_get_suppress_shape_guards_ctx():
        return contextlib.nullcontext()

    return shape_env.suppress_guards()


def aoti_eager_cache_dir(namespace: str, device: str):
    return Path(cache_dir()) / "aoti_eager" / namespace / device


def aoti_eager_op_conf_lock(op_func_name_with_overload: str):
    from filelock import FileLock

    # Avoid circular import
    from torch._inductor.codecache import get_lock_dir, LOCK_TIMEOUT

    op_conf_lock_file = f"{op_func_name_with_overload}.lock"
    lock_dir = get_lock_dir()
    return FileLock(os.path.join(lock_dir, op_conf_lock_file), timeout=LOCK_TIMEOUT)


def load_aoti_eager_cache(ns: str, op_func_name_with_overload: str, device_type: str):
    device_kernel_cache = aoti_eager_cache_dir(ns, device_type)
    op_conf = device_kernel_cache / f"{op_func_name_with_overload}.json"
    if not op_conf.exists():
        return []

    with aoti_eager_op_conf_lock(op_func_name_with_overload):
        with open(op_conf) as f:
            json_data = json.load(f)
            for item in json_data:
                # Get the absolute path of the kernel library
                kernel_lib_abs_path = device_kernel_cache / item["kernel_path"]
                item["kernel_path"] = kernel_lib_abs_path.as_posix()

                # Check if the kernel library exists
                if not kernel_lib_abs_path.exists():
                    return []

                for metadata in item["meta_info"]:
                    assert not metadata[
                        "is_dynamic"
                    ], "Only support static shape for now"
                    if metadata["device_type"] == "cpu":
                        metadata["device_index"] = -1
                    metadata["dtype"] = getattr(torch, metadata["dtype"].split(".")[-1])

            return json_data


def aoti_compile_with_persistent_cache(
    ns: str,
    op_func_name_with_overload: str,
    device_type: str,
    dynamic: bool,
    f: Callable[..., Any],
    args: Tuple[Any],
    kwargs: Dict[str, Any],
    *,
    dynamic_shapes: Optional[Dict[str, Any]] = None,
    options: Optional[Dict[str, Any]] = None,
    remove_runtime_assertions: bool = False,
    disable_constraint_solver: bool = False,
):
    """
    Compile the given function with persistent cache for AOTI eager mode.
    """
    flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)
    assert all(
        isinstance(input, torch.Tensor) for input in flattened_inputs
    ), "Only support tensor for now"
    assert not dynamic, "Only support static shape for now"

    persistent_cache = aoti_eager_cache_dir(ns, device_type)
    persistent_cache.mkdir(parents=True, exist_ok=True)
    persistent_cache_lib = persistent_cache / "lib"
    persistent_cache_lib.mkdir(parents=True, exist_ok=True)

    with mock.patch.dict(
        os.environ,
        {"TORCHINDUCTOR_CACHE_DIR": persistent_cache_lib.absolute().as_posix()},
    ):
        try:
            kernel_lib_path = torch._export.aot_compile(
                f,
                args,
                kwargs,
                dynamic_shapes=dynamic_shapes,
                options=options,
                remove_runtime_assertions=remove_runtime_assertions,
                disable_constraint_solver=disable_constraint_solver,
            )

            kernel_metadata_items = []
            for input_tensor in flattened_inputs:
                # TODO(Eikan): To add dynamic support
                metadata: Dict[str, Any] = {}
                metadata["is_dynamic"] = dynamic
                metadata["device_type"] = f"{input_tensor.device.type}"
                if is_cpu_device([input_tensor]):
                    metadata["device_index"] = -1
                else:
                    metadata["device_index"] = input_tensor.device.index
                metadata["dtype"] = f"{input_tensor.dtype}"
                metadata["sizes"] = list(input_tensor.size())
                metadata["strides"] = list(input_tensor.stride())
                kernel_metadata_items.append(metadata)

            kernel_meta_info: Dict[str, Any] = {}
            kernel_meta_info["meta_info"] = kernel_metadata_items
            kernel_meta_info["kernel_path"] = (
                Path(kernel_lib_path).relative_to(persistent_cache).as_posix()
            )

            json_data = []
            update_json = True
            op_conf = persistent_cache / f"{op_func_name_with_overload}.json"
            mode = "r" if op_conf.exists() else "w"
            with aoti_eager_op_conf_lock(op_func_name_with_overload):
                with open(op_conf, mode) as op_conf_file:
                    try:
                        json_data = json.load(op_conf_file)
                    except Exception as e:
                        json_data = []

                    assert isinstance(json_data, list)
                    for item in json_data:
                        assert isinstance(item, dict)
                        # Same kernel meta info already exists in the json file
                        if item["meta_info"] == kernel_metadata_items:
                            update_json = False
                            break

                if update_json:
                    json_data.append(kernel_meta_info)
                    with open(op_conf, "w") as op_conf_file:
                        json.dump(json_data, op_conf_file, indent=4)

            return kernel_lib_path
        except Exception as e:
            return ""
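For quick orientation, here is a minimal usage sketch of the two utilities added above, modeled on the new tests earlier in this commit; the `aten::abs` choice and the CPU tensor are just example inputs, and it assumes a build that contains this PR.

```python
import torch
from torch._inductor.utils import (
    aoti_compile_with_persistent_cache,
    aoti_eager_cache_dir,
    load_aoti_eager_cache,
)

x = torch.randn(128, dtype=torch.float, device="cpu")

# Compile aten::abs once and persist the kernel plus its meta_info on disk.
lib_path = aoti_compile_with_persistent_cache(
    "aten", "abs", x.device.type, False, torch.ops.aten.abs, (x,), {}
)
print("kernel library:", lib_path)
print("cache dir:", aoti_eager_cache_dir("aten", x.device.type))

# Later runs can read the config back; kernel_path is resolved to an absolute path.
for entry in load_aoti_eager_cache("aten", "abs", x.device.type):
    print(entry["kernel_path"], [m["sizes"] for m in entry["meta_info"]])
```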
@@ -4,6 +4,7 @@
#include <ATen/ATen.h>

#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/csrc/Dtype.h>
#include <torch/csrc/PyInterpreter.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
@@ -12,6 +13,11 @@
#endif
#include <torch/csrc/jit/frontend/function_schema_parser.h>

#include <ATen/core/jit_type.h>
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>

namespace torch::inductor {

namespace {
@@ -115,14 +121,16 @@ AOTIPythonKernelHolder::AOTIPythonKernelHolder(
          (device_.type() == c10::DeviceType::CPU) ||
              (device_.type() == c10::DeviceType::CUDA),
      "Unsupported device type");
+  init_aoti_kernel_cache();
}

void AOTIPythonKernelHolder::operator()(
    const c10::OperatorHandle& op,
    c10::DispatchKeySet keyset,
    torch::jit::Stack* stack) {
-  if (cache_lookup(op, keyset, stack)) {
-    cache_hit(op, keyset, stack);
+  AOTIKernelState kernel_state;
+  if (cache_lookup(op, keyset, stack, kernel_state)) {
+    cache_hit(kernel_state, op, keyset, stack);
  } else {
    cache_miss(op, keyset, stack);
  }
@@ -130,23 +138,190 @@ void AOTIPythonKernelHolder::operator()(

bool AOTIPythonKernelHolder::cache_lookup(
    const c10::OperatorHandle& op,
-    c10::DispatchKeySet keyset,
-    torch::jit::Stack* stack) {
-  // TODO: Always return false now to implement cache_miss. Later, we will add
-  // cache lookup and implement cache hit.
+    const c10::DispatchKeySet& keyset,
+    const torch::jit::Stack* stack,
+    AOTIKernelState& kernel_state) {
  TORCH_CHECK_NOT_IMPLEMENTED(
      op.schema().returns().size() == 1,
      "Not implemented for operations that return either multiple values or no value.");
  TORCH_CHECK_NOT_IMPLEMENTED(
      op.schema().returns()[0].type()->isSubtypeOf(c10::TensorType::get()),
      "Not implemented for operations that return a non-Tensor value.");

  std::vector<at::Tensor> inputs;
  auto res = unpack_tensors(op.schema().arguments(), *stack, device_, inputs);
  TORCH_CHECK_NOT_IMPLEMENTED(
      res && inputs.size() > 0,
      "Not implemented for operations that contain a parameter which is ",
      "not one of the following types: at::Tensor, at::TensorList, ",
      "std::optional<at::Tensor>, std::vector<std::optional<at::Tensor>>.");

  auto inputs_metadata = get_inputs_metadata(inputs);
  auto aoti_kernel_state = aoti_kernel_cache_.find(inputs_metadata);
  if (aoti_kernel_state == aoti_kernel_cache_.end()) {
    return false;
  }

  if (aoti_kernel_state->second.tensor_checks_.size() != inputs.size()) {
    return false;
  }

  torch::dynamo::LocalState local_state;
  local_state.overrideDispatchKeySet(c10::DispatchKeySet(dispatch_key_));

  for (size_t i = 0; i < inputs.size(); ++i) {
    bool pass = aoti_kernel_state->second.tensor_checks_[i].check(
        local_state, inputs[i]);
    if (!pass) {
      return false;
    }
  }

  kernel_state = aoti_kernel_state->second;
  return true;
}

void AOTIPythonKernelHolder::cache_hit(
+    const AOTIKernelState& kernel_state,
    const c10::OperatorHandle& op,
-    c10::DispatchKeySet keyset,
+    const c10::DispatchKeySet& keyset,
    torch::jit::Stack* stack) {
-  TORCH_INTERNAL_ASSERT(false);
+  std::vector<at::Tensor> inputs;
  unpack_tensors(op.schema().arguments(), *stack, device_, inputs);
  torch::jit::drop(*stack, op.schema().arguments().size());

  auto outputs = kernel_state.kernel_runner_->run(inputs);
  for (auto& output : outputs) {
    stack->push_back(output);
  }
}

AOTIKernelMetadata AOTIPythonKernelHolder::get_inputs_metadata(
    const std::vector<at::Tensor>& inputs) {
  AOTIKernelMetadata inputs_metadata;
  for (const auto& input : inputs) {
    auto device = input.device();
    if (device.is_cpu()) {
      // If the device is CPU, set the device index to -1.
      device = c10::Device(device.type(), -1);
    }

    inputs_metadata.emplace_back(
        false, // is symbolic
        input.scalar_type(),
        device,
        input.sizes().vec(),
        input.strides().vec());
  }
  return inputs_metadata;
}

void AOTIPythonKernelHolder::init_aoti_kernel_cache() {
  if (device_.type() == c10::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES) {
    return;
  }

  py::gil_scoped_acquire gil;

  py::handle load_aoti_eager_cache_function =
      py::module::import("torch._inductor.utils").attr("load_aoti_eager_cache");
  TORCH_INTERNAL_ASSERT(
      load_aoti_eager_cache_function.ptr() != nullptr,
      "Failed to import - torch._inductor.utils.load_aoti_eager_cache");

  auto result = py::reinterpret_steal<py::object>(PyObject_CallFunctionObjArgs(
      load_aoti_eager_cache_function.ptr(),
      py::str(ns_).ptr(),
      py::str(op_name_with_overload_).ptr(),
      py::str(c10::DeviceTypeName(device_.type(), true)).ptr(),
      nullptr));
  TORCH_INTERNAL_ASSERT(
      result.ptr() != nullptr && result.ptr() != Py_None,
      "Failed to load AOTI kernel. Operator Name is ",
      op_name_with_overload_);

  auto kernel_info_list = result.cast<py::list>();
  for (auto kernel_info : kernel_info_list) {
    auto item_dict = kernel_info.cast<py::dict>();

    // Access the kernel_path field
    auto kernel_path = item_dict["kernel_path"].cast<std::string>();

    // Access the meta_info list
    auto inputs_metadata = item_dict["meta_info"].cast<py::list>();

    std::vector<torch::dynamo::TensorCheck> tensor_checks;
    std::vector<TensorMetadata> tensor_metadata_list;

    torch::dynamo::LocalState state;
    // Loop over the meta_info list
    for (auto item : inputs_metadata) {
      // Convert the handle to a dict
      auto metadata = item.cast<py::dict>();

      // Access the fields of each metadata dict
      auto is_dynamic = metadata["is_dynamic"].cast<bool>();
      auto device_type = metadata["device_type"].cast<std::string>();
      auto device_index = metadata["device_index"].cast<int8_t>();
      auto data_type_obj = metadata["dtype"].cast<py::object>();
      TORCH_INTERNAL_ASSERT(THPDtype_Check(data_type_obj.ptr()));
      auto data_type =
          reinterpret_cast<THPDtype*>(data_type_obj.ptr())->scalar_type;
      auto sizes = metadata["sizes"].cast<std::vector<int64_t>>();
      auto strides = metadata["strides"].cast<std::vector<int64_t>>();

      std::vector<std::optional<c10::SymInt>> sym_optional_sizes;
      std::vector<std::optional<c10::SymInt>> sym_optional_strides;
      for (int64_t size : sizes) {
        sym_optional_sizes.push_back(std::optional<c10::SymInt>(size));
      }
      for (int64_t stride : strides) {
        sym_optional_strides.push_back(std::optional<c10::SymInt>(stride));
      }

      // Build the tensor metadata and the corresponding TensorCheck guard.
      tensor_metadata_list.emplace_back(
          is_dynamic,
          data_type,
          c10::Device(c10::Device(device_type).type(), device_index),
          sizes,
          strides);
      tensor_checks.emplace_back(
          state,
          nullptr,
          uint64_t(c10::DispatchKeySet(dispatch_key_).raw_repr()),
          data_type,
          c10::DeviceIndex(device_index),
          sym_optional_sizes,
          sym_optional_strides);
    }

    AOTIKernelState aoti_kernel_state;
    aoti_kernel_state.kernel_runner_ = load_aoti_model_runner(kernel_path);
    aoti_kernel_state.tensor_checks_ = tensor_checks;
    aoti_kernel_cache_[tensor_metadata_list] = aoti_kernel_state;
  }
}

std::shared_ptr<AOTIModelContainerRunner> AOTIPythonKernelHolder::
    load_aoti_model_runner(const std::string& so_path) {
  if (device_.type() == c10::DeviceType::CUDA) {
#ifdef USE_CUDA
    return std::make_shared<AOTIModelContainerRunnerCuda>(so_path);
#else
    return nullptr;
#endif
  } else if (device_.type() == c10::DeviceType::CPU) {
    return std::make_shared<AOTIModelContainerRunnerCpu>(so_path);
  } else {
    TORCH_WARN("Unsupported device type");
    return nullptr;
  }
}

void AOTIPythonKernelHolder::cache_miss(
    const c10::OperatorHandle& op,
-    c10::DispatchKeySet keyset,
+    const c10::DispatchKeySet& keyset,
    torch::jit::Stack* stack) {
  auto kernel_lib_path = produce_aoti_kernel_lib(op, keyset, stack);
  std::shared_ptr<AOTIModelContainerRunner> kernel = nullptr;
@@ -167,7 +342,6 @@ void AOTIPythonKernelHolder::cache_miss(
      unpack_tensors(op.schema().arguments(), *stack, device_, inputs),
      "Failed to unpack tensors for the stack to run the AOTI kernel.");
  auto outputs = kernel->run(inputs);
  if (outputs.size() > 0) {
    torch::jit::drop(*stack, op.schema().arguments().size());
    // TODO: Get the output type of this operation and then convert to the
    // output type.
@@ -175,33 +349,34 @@
      torch::jit::push(*stack, std::move(output));
    }
  }
}

std::string AOTIPythonKernelHolder::produce_aoti_kernel_lib(
    const c10::OperatorHandle& op,
-    c10::DispatchKeySet keyset,
-    torch::jit::Stack* stack) {
+    const c10::DispatchKeySet& keyset,
+    const torch::jit::Stack* stack) {
  auto arguments = torch::jit::last(*stack, op.schema().arguments().size());

-  py::gil_scoped_acquire gil;

  // Get the corresponding python operation for the current operator, and the
  // python operation will be passed to the AOT Inductor to generate the kernel
  // library.
  const auto& schema = op.schema();
  const auto& qualified_name = op.operator_name().name;
  const auto& overload_name =
      schema.overload_name().empty() ? "default" : schema.overload_name();
  auto pos = qualified_name.find("::");
  TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name);
-  // Make me some null terminated strings
-  std::string ns_str = qualified_name.substr(0, pos);
-  const char* ns = ns_str.c_str();
-  const char* func_name = qualified_name.c_str() + pos + strlen("::");
+  std::string ns_str(qualified_name.begin(), qualified_name.begin() + pos);
+  std::string func_name(
+      qualified_name.begin() + pos + strlen("::"), qualified_name.end());

+  py::gil_scoped_acquire gil;
  py::handle op_py_func = op.getPythonOp(pyinterpreter_, [&]() -> PyObject* {
-    py::handle torch_api_function =
-        py::module::import("torch").attr("ops").attr(ns).attr(func_name);
+    py::handle torch_api_function = py::module::import("torch")
+                                        .attr("ops")
+                                        .attr(ns_str.c_str())
+                                        .attr(func_name.c_str());
    if (overload_name.empty()) {
      return torch_api_function.attr("default").ptr();
    } else {
      return torch_api_function.attr(overload_name.c_str()).ptr();
    }
  });

  TORCH_INTERNAL_ASSERT(
@@ -212,17 +387,22 @@ std::string AOTIPythonKernelHolder::produce_aoti_kernel_lib(
      overload_name);

  py::handle aot_compile_function =
-      py::module::import("torch._export").attr("aot_compile");
+      py::module::import("torch._inductor.utils")
+          .attr("aoti_compile_with_persistent_cache");
  TORCH_INTERNAL_ASSERT(
      aot_compile_function.ptr() != nullptr &&
          aot_compile_function.ptr() != Py_None,
-      "Failed to import - torch._export.aot_compile");
+      "Failed to import - torch._inductor.utils.aoti_compile_with_persistent_cache");

  // Pass the python operation to the AOT Inductor to generate the kernel
  // library.
  auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments.vec());
  auto result = py::reinterpret_steal<py::object>(PyObject_CallFunctionObjArgs(
      aot_compile_function.ptr(),
+      py::str(ns_str).ptr(),
+      py::str(op_name_with_overload_).ptr(),
+      py::str(c10::DeviceTypeName(device_.type(), true)).ptr(),
+      py::bool_(false).ptr(),
      op_py_func.ptr(),
      args_kwargs.first.ptr(),
      args_kwargs.second.ptr(),
@@ -4,6 +4,8 @@
#include <ATen/ATen.h>
#include <ATen/core/boxing/KernelFunction.h>

#include <torch/csrc/dynamo/guards.h>
#include <torch/csrc/inductor/aoti_eager/kernel_meta_info.h>
#include <torch/csrc/inductor/aoti_runner/model_container_runner.h>
#include <torch/csrc/utils/pybind.h>

@@ -11,6 +13,11 @@

namespace torch::inductor {

struct AOTIKernelState {
  std::shared_ptr<AOTIModelContainerRunner> kernel_runner_;
  std::vector<torch::dynamo::TensorCheck> tensor_checks_;
};

// The AOTIPythonKernelHolder class uses the AOT Inductor to generate a kernel
// for a specified operation. To speed up this process, the generated kernel
// library is cached on disk. Detailed information from the input tensors is
@@ -31,6 +38,10 @@ class AOTIPythonKernelHolder : public c10::OperatorKernel {
  // op_overload_name.
  c10::impl::PyInterpreter* pyinterpreter_;

  std::
      unordered_map<AOTIKernelMetadata, AOTIKernelState, AOTIKernelMetadataHash>
          aoti_kernel_cache_;

 public:
  AOTIPythonKernelHolder(
      c10::DispatchKey dispatch_key,
@@ -45,20 +56,36 @@ class AOTIPythonKernelHolder : public c10::OperatorKernel {
 private:
  bool cache_lookup(
      const c10::OperatorHandle& op,
-      c10::DispatchKeySet keyset,
-      torch::jit::Stack* stack);
+      const c10::DispatchKeySet& keyset,
+      const torch::jit::Stack* stack,
+      AOTIKernelState& kernel_state);
  void cache_miss(
      const c10::OperatorHandle& op,
-      c10::DispatchKeySet keyset,
+      const c10::DispatchKeySet& keyset,
      torch::jit::Stack* stack);
  void cache_hit(
+      const AOTIKernelState& kernel_state,
      const c10::OperatorHandle& op,
-      c10::DispatchKeySet keyset,
+      const c10::DispatchKeySet& keyset,
      torch::jit::Stack* stack);
  // Invoke the python utility function on the Inductor side to produce the
  // AOTI kernel for the given operation.
  // Inductor utility function -
  // torch._inductor.utils.aoti_compile_with_persistent_cache
  std::string produce_aoti_kernel_lib(
      const c10::OperatorHandle& op,
-      c10::DispatchKeySet keyset,
-      torch::jit::Stack* stack);
+      const c10::DispatchKeySet& keyset,
+      const torch::jit::Stack* stack);
  // Invoke the python utility function on the Inductor side to load the AOTI
  // kernel for the given operation.
  // Inductor utility function - torch._inductor.utils.load_aoti_eager_cache
  void init_aoti_kernel_cache();
  // Abstract the meta information of each tensor for the given operation. The
  // meta information will be used for cache lookup as the key.
  AOTIKernelMetadata get_inputs_metadata(const std::vector<at::Tensor>&);
  // Load the AOTIModelContainerRunner object from the given file path.
  std::shared_ptr<AOTIModelContainerRunner> load_aoti_model_runner(
      const std::string&);
};

} // namespace torch::inductor
torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp (new file, 64 lines)
@@ -0,0 +1,64 @@
#if !defined(C10_MOBILE) && !defined(ANDROID)
#include <torch/csrc/inductor/aoti_eager/kernel_meta_info.h>

namespace torch::inductor {

TensorMetadata::TensorMetadata(const at::Tensor& src_tensor)
    : is_symbolic_(false),
      dtype_(src_tensor.scalar_type()),
      device_(src_tensor.device()),
      sizes_(src_tensor.sizes().vec()),
      strides_(src_tensor.strides().vec()) {}

TensorMetadata::TensorMetadata(
    bool is_symbolic,
    c10::ScalarType dtype,
    c10::Device device,
    std::vector<int64_t> sizes,
    std::vector<int64_t> strides)
    : is_symbolic_(is_symbolic),
      dtype_(dtype),
      device_(device),
      sizes_(sizes),
      strides_(strides) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      !is_symbolic_, "Not support symbolic shape now");
}

bool TensorMetadata::operator==(const TensorMetadata& other) const {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      !is_symbolic_, "Not support symbolic shape now");
  return this->is_symbolic_ == other.is_symbolic_ &&
      this->dtype_ == other.dtype_ &&
      this->device_.type() == other.device_.type() &&
      this->sizes_ == other.sizes_ && this->strides_ == other.strides_;
}

size_t TensorMetadataHash::operator()(
    const TensorMetadata& tensor_metadata) const {
  auto hash = std::hash<bool>()(tensor_metadata.is_symbolic_);
  hash = c10::hash_combine(
      hash, std::hash<c10::ScalarType>()(tensor_metadata.dtype_));
  hash = c10::hash_combine(
      hash, std::hash<c10::DeviceType>()(tensor_metadata.device_.type()));

  for (auto& e : tensor_metadata.sizes_) {
    hash = c10::hash_combine(hash, std::hash<int64_t>()(e));
  }

  for (auto& e : tensor_metadata.strides_) {
    hash = c10::hash_combine(hash, std::hash<int64_t>()(e));
  }
  return hash;
}

size_t AOTIKernelMetadataHash::operator()(
    const AOTIKernelMetadata& aoti_kernel_metadata) const {
  size_t hash = 0;
  for (auto& e : aoti_kernel_metadata) {
    hash = c10::hash_combine(hash, TensorMetadataHash()(e));
  }
  return hash;
}

} // namespace torch::inductor
#endif
torch/csrc/inductor/aoti_eager/kernel_meta_info.h (new file, 67 lines)
@@ -0,0 +1,67 @@
#if !defined(C10_MOBILE) && !defined(ANDROID)
#pragma once

#include <ATen/ATen.h>
#include <c10/core/SymIntArrayRef.h>

#include <string>

namespace torch::inductor {

// For an aten operation implemented by AOTI, the metadata of the input tensors
// is cached on disk to accelerate the next run. The TensorMetadata structure
// represents the metadata of each input tensor. It includes whether the tensor
// is symbolic, the dtype, the device, the sizes, and the strides of the tensor.
// When the metadata of the input tensors is the same as the cached metadata,
// the cached kernel library will be loaded and executed. Otherwise, the AOT
// Inductor will be called again to generate the kernel library.
// Beyond the TensorMetadata, we build a guard/TensorCheck for each input tensor
// as well to support symbolic shapes. We intend to utilize TensorCheck to find
// the proper kernel rather than comparing TensorMetadata. Suppose an operation
// with a single input tensor and two kernels:
//   kernel1: TensorMetadata(is_symbolic=false, dtype=Float, device=CPU,
//            sizes=[s0, s1, s2], strides=[s1 * s2, s2, 1])
//   kernel2: TensorMetadata(is_symbolic=false, dtype=Float, device=CPU,
//            sizes=[3, s1, s2], strides=[s1 * s2, s2, 1])
// If a tensor with sizes=[3, 4, 5] is passed to the operation, both kernel1 and
// kernel2 support the tensor shape. In this case, we need to use TensorCheck
// plus some heuristic rules to find the proper kernel.
struct TensorMetadata {
  // Indicates whether the tensor is symbolic; this may be concluded from
  // sizes_ and strides_ in the future.
  bool is_symbolic_;
  // Dtype of a tensor (for a scalar, we will wrap it as a scalar tensor)
  c10::ScalarType dtype_;
  // Device of a tensor.
  c10::Device device_;
  // Sizes of a tensor. Currently, we only support static shapes and use
  // int64_t to represent the sizes. In the future, we will create symbolic
  // sizes and use SymInt to represent them to support symbolic shapes.
  std::vector<int64_t> sizes_;
  // Strides of a tensor. For symbolic shape support, they will be handled the
  // same way as sizes_.
  std::vector<int64_t> strides_;

  TensorMetadata(const at::Tensor& src_tensor);
  TensorMetadata(
      bool is_symbolic,
      c10::ScalarType dtype,
      c10::Device device,
      std::vector<int64_t> sizes,
      std::vector<int64_t> strides);

  bool operator==(const TensorMetadata& other) const;
};

struct TensorMetadataHash {
  size_t operator()(const TensorMetadata&) const;
};

using AOTIKernelMetadata = std::vector<TensorMetadata>;

struct AOTIKernelMetadataHash {
  size_t operator()(const AOTIKernelMetadata&) const;
};

} // namespace torch::inductor
#endif