Mirror of https://github.com/pytorch/pytorch.git
[AOTI] Add pybind for AOTIModelContainerRunnerCpu and AOTIModelContainerRunnerCuda (#116269)
Summary: Now we can allocate an AOTIModelContainerRunner object instead of relying on torch.utils.cpp_extension.load_inline. This PR also renames AOTInductorModelRunner to AOTIRunnerUtil.

Test Plan: CI

Reviewed By: khabinov

Differential Revision: D52339116

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116269
Approved by: https://github.com/khabinov
Committed by: PyTorch MergeBot
Parent: 56d7a47806
Commit: 70f3a530d7
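In short, an AOTInductor-compiled model.so can now be driven from Python through the new torch._C._aoti bindings instead of a torch.utils.cpp_extension.load_inline extension. Below is a minimal sketch of that flow, mirroring the updated AOTIRunnerUtil.load in this diff; the step that produces model_so_path (e.g. via torch._export.aot_compile) is assumed and is not part of this commit.

    import torch
    import torch.fx._pytree as fx_pytree
    from torch.utils import _pytree as pytree

    # Assumed: model_so_path points at an AOTInductor-compiled shared library,
    # e.g. a path produced by torch._export.aot_compile(model, example_inputs).
    model_so_path = "/tmp/model.so"

    # Allocate a runner directly; the second argument is the number of model instances.
    # Use torch._C._aoti.AOTIModelContainerRunnerCuda for a CUDA-compiled model.
    runner = torch._C._aoti.AOTIModelContainerRunnerCpu(model_so_path, 1)

    def optimized(*args):
        # The runner operates on flat lists of tensors; the call spec describes
        # how to flatten the inputs and unflatten the outputs.
        call_spec = runner.get_call_spec()
        in_spec = pytree.treespec_loads(call_spec[0])
        out_spec = pytree.treespec_loads(call_spec[1])
        flat_inputs = fx_pytree.tree_flatten_spec((*args, {}), in_spec)
        flat_outputs = runner.run(flat_inputs)
        return pytree.tree_unflatten(flat_outputs, out_spec)

    # optimized(*example_inputs) then behaves like the original model's forward.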
@@ -807,6 +807,7 @@ libtorch_python_core_sources = [
     "torch/csrc/dynamo/init.cpp",
     "torch/csrc/functorch/init.cpp",
     "torch/csrc/mps/Module.cpp",
+    "torch/csrc/inductor/aoti_runner/pybind.cpp",
     "torch/csrc/jit/backends/backend_init.cpp",
     "torch/csrc/jit/python/init.cpp",
     "torch/csrc/jit/passes/onnx.cpp",
@@ -29,9 +29,7 @@ def load_test_module(name):
     ).load_module()


-AOTInductorModelRunner = load_test_module(
-    "inductor.test_aot_inductor"
-).AOTInductorModelRunner
+AOTIRunnerUtil = load_test_module("inductor.test_aot_inductor_utils").AOTIRunnerUtil


 import sys
@@ -277,7 +275,7 @@ class C10DFunctionalNativeTest(MultiProcessTestCase):
         assert same(out, correct), f"{out} va {correct}"

         # Test aoti
-        out = AOTInductorModelRunner.run("cuda", func, (arg,))
+        out = AOTIRunnerUtil.run("cuda", func, (arg,))
         torch.cuda.synchronize()

     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@@ -332,7 +330,7 @@ class C10DFunctionalNativeTest(MultiProcessTestCase):
         assert same(out, correct), f"{out} va {correct}"

         # Test aoti
-        out = AOTInductorModelRunner.run("cuda", func, (args,))
+        out = AOTIRunnerUtil.run("cuda", func, (args,))
         torch.cuda.synchronize()

     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@@ -406,7 +404,7 @@ class C10DFunctionalNativeTest(MultiProcessTestCase):
         assert same(out, correct), f"{out} va {correct}"

         # Test aoti
-        out = AOTInductorModelRunner.run("cuda", func, (arg,))
+        out = AOTIRunnerUtil.run("cuda", func, (arg,))
         torch.cuda.synchronize()

     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@@ -449,7 +447,7 @@ class C10DFunctionalNativeTest(MultiProcessTestCase):
         assert same(out, correct), f"{out} va {correct}"

         # Test aoti
-        out = AOTInductorModelRunner.run("cuda", func, (args,))
+        out = AOTIRunnerUtil.run("cuda", func, (args,))
         torch.cuda.synchronize()

     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@@ -483,7 +481,7 @@ class C10DFunctionalNativeTest(MultiProcessTestCase):
         assert same(out, correct), f"{out} va {correct}"

         # Test aoti
-        out = AOTInductorModelRunner.run("cuda", func, (arg,))
+        out = AOTIRunnerUtil.run("cuda", func, (arg,))
         torch.cuda.synchronize()

     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@@ -525,7 +523,7 @@ class C10DFunctionalNativeTest(MultiProcessTestCase):
         assert same(out, correct), f"{out} va {correct}"

         # Test aoti
-        out = AOTInductorModelRunner.run("cuda", func, (args,))
+        out = AOTIRunnerUtil.run("cuda", func, (args,))
         torch.cuda.synchronize()

     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@@ -49,10 +49,10 @@ if IS_WINDOWS and IS_CI:

 try:
     try:
-        from .test_aot_inductor_utils import AOTInductorModelRunner
+        from .test_aot_inductor_utils import AOTIRunnerUtil
         from .test_torchinductor import copy_tests, requires_multigpu, TestFailure
     except ImportError:
-        from test_aot_inductor_utils import AOTInductorModelRunner
+        from test_aot_inductor_utils import AOTIRunnerUtil
         from test_torchinductor import copy_tests, requires_multigpu, TestFailure
 except (unittest.SkipTest, ImportError) as e:
     if __name__ == "__main__":
@@ -82,7 +82,7 @@ def check_model(
     expected = ref_model(*ref_inputs)

     torch.manual_seed(0)
-    actual = AOTInductorModelRunner.run(
+    actual = AOTIRunnerUtil.run(
         self.device,
         model,
         example_inputs,
@@ -114,7 +114,7 @@ def check_model_with_multiple_inputs(
     list_expected = [ref_model(*inputs) for inputs in ref_inputs]

     torch.manual_seed(0)
-    list_actual = AOTInductorModelRunner.run_multiple(
+    list_actual = AOTIRunnerUtil.run_multiple(
         self.device, model, list_example_inputs, options, constraints
     )

@@ -181,7 +181,7 @@ class AOTInductorTestsTemplate:
             torch.randn(10, 10, device=self.device),
         )
         expected_path = os.path.join(tempfile.mkdtemp(dir=cache_dir()), "model.so")
-        actual_path = AOTInductorModelRunner.compile(
+        actual_path = AOTIRunnerUtil.compile(
             model, example_inputs, options={"aot_inductor.output_path": expected_path}
         )
         self.assertTrue(actual_path == expected_path)
@@ -788,7 +788,7 @@ class AOTInductorTestsTemplate:
         with torch.cuda.device(0), config.patch(
             "aot_inductor.abi_compatible", self.abi_compatible
         ):
-            so_path = AOTInductorModelRunner.compile(
+            so_path = AOTIRunnerUtil.compile(
                 model=Model(w1.cuda(0), w2.cuda(0)),
                 example_inputs=tuple(t.cuda(0) for t in inputs),
             )
@@ -797,7 +797,7 @@ class AOTInductorTestsTemplate:
         for i in range(torch.cuda.device_count()):
             with torch.cuda.device(i):
                 example_inputs = tuple(t.cuda(i) for t in inputs)
-                optimized = AOTInductorModelRunner.load("cuda", so_path, example_inputs)
+                optimized = AOTIRunnerUtil.load("cuda", so_path)
                 result_cuda = optimized(example_inputs)
                 self.assertTrue(same(result_cpu, result_cuda.cpu()))

@@ -837,14 +837,14 @@ class AOTInductorTestsTemplate:
         with torch.cuda.device(0), torch.no_grad(), config.patch(
             "aot_inductor.abi_compatible", self.abi_compatible
         ):
-            result_cuda_0 = AOTInductorModelRunner.run(
+            result_cuda_0 = AOTIRunnerUtil.run(
                 "cuda", Model(weight.cuda(0)), tuple(t.cuda(0) for t in inputs)
             )

         with torch.cuda.device(1), torch.no_grad(), config.patch(
             "aot_inductor.abi_compatible", self.abi_compatible
         ):
-            result_cuda_1 = AOTInductorModelRunner.run(
+            result_cuda_1 = AOTIRunnerUtil.run(
                 "cuda", Model(weight.cuda(1)), tuple(t.cuda(1) for t in inputs)
             )

@@ -1006,12 +1006,12 @@ class AOTInductorTestsTemplate:

         # compiler under no_grad
         with torch.no_grad():
-            so_path = AOTInductorModelRunner.compile(m, example_inputs)
+            so_path = AOTIRunnerUtil.compile(m, example_inputs)

         # run under grad enabled
         self.assertTrue(torch.is_grad_enabled())

-        optimized = AOTInductorModelRunner.load(self.device, so_path, example_inputs)
+        optimized = AOTIRunnerUtil.load(self.device, so_path)
         actual = optimized(example_inputs)
         actual = pytree.tree_leaves(actual)

@@ -1,18 +1,16 @@
 # Owner(s): ["module: inductor"]
-import tempfile

 import torch
 import torch._export
 import torch._inductor
 import torch.fx._pytree as fx_pytree
-from torch._inductor.utils import aot_inductor_launcher, cache_dir

 from torch.testing._internal.common_utils import IS_FBCODE

 from torch.utils import _pytree as pytree


-class AOTInductorModelRunner:
+class AOTIRunnerUtil:
     @classmethod
     def compile(
         cls,
@@ -34,41 +32,27 @@ class AOTInductorModelRunner:
         return so_path

     @classmethod
-    def load(cls, device, so_path, example_inputs):
+    def load(cls, device, so_path):
         if IS_FBCODE:
             from .fb import test_aot_inductor_model_runner_pybind

-            module = test_aot_inductor_model_runner_pybind.Runner(
+            runner = test_aot_inductor_model_runner_pybind.Runner(
                 so_path, device == "cpu"
             )
-
-            call_spec = module.get_call_spec()
-            in_spec = pytree.treespec_loads(call_spec[0])
-            out_spec = pytree.treespec_loads(call_spec[1])
-
-            def optimized(*args):
-                flat_inputs = fx_pytree.tree_flatten_spec((*args, {}), in_spec)
-                flat_outputs = module.run(flat_inputs)
-                return pytree.tree_unflatten(flat_outputs, out_spec)
-
         else:
-            module = torch.utils.cpp_extension.load_inline(
-                name="aot_inductor",
-                cpp_sources=[aot_inductor_launcher(so_path, device)],
-                # use a unique build directory to avoid test interference
-                build_directory=tempfile.mkdtemp(dir=cache_dir()),
-                functions=["run", "get_call_spec"],
-                with_cuda=(device == "cuda"),
+            runner = (
+                torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1)
+                if device == "cpu"
+                else torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1)
             )

-            call_spec = module.get_call_spec()
+        def optimized(*args):
+            call_spec = runner.get_call_spec()
             in_spec = pytree.treespec_loads(call_spec[0])
             out_spec = pytree.treespec_loads(call_spec[1])
-
-            def optimized(*args):
-                flat_inputs = fx_pytree.tree_flatten_spec((*args, {}), in_spec)
-                flat_outputs = module.run(flat_inputs)
-                return pytree.tree_unflatten(flat_outputs, out_spec)
+            flat_inputs = fx_pytree.tree_flatten_spec((*args, {}), in_spec)
+            flat_outputs = runner.run(flat_inputs)
+            return pytree.tree_unflatten(flat_outputs, out_spec)

         return optimized
@@ -82,14 +66,14 @@ class AOTInductorModelRunner:
         constraints=None,
         disable_constraint_solver=False,
     ):
-        so_path = AOTInductorModelRunner.compile(
+        so_path = AOTIRunnerUtil.compile(
             model,
             example_inputs,
             options=options,
             constraints=constraints,
             disable_constraint_solver=disable_constraint_solver,
         )
-        optimized = AOTInductorModelRunner.load(device, so_path, example_inputs)
+        optimized = AOTIRunnerUtil.load(device, so_path)
         return optimized(example_inputs)

     @classmethod
@@ -101,13 +85,13 @@ class AOTInductorModelRunner:
         options=None,
         constraints=None,
     ):
-        so_path = AOTInductorModelRunner.compile(
+        so_path = AOTIRunnerUtil.compile(
             model,
             list_example_inputs[0],
             options=options,
             constraints=constraints,
         )
-        optimized = AOTInductorModelRunner.load(device, so_path, list_example_inputs[0])
+        optimized = AOTIRunnerUtil.load(device, so_path)
         list_output_tensors = []
         for example_inputs in list_example_inputs:
             list_output_tensors.append(optimized(example_inputs))
@@ -80,7 +80,7 @@ class TestMemoryPlanning(TestCase):

     @skipIfRocm(msg="test_aot_inductor doesn't work on ROCm")
     def test_abi_compatible(self):
-        from test_aot_inductor import AOTInductorModelRunner
+        from test_aot_inductor import AOTIRunnerUtil

         f, args = self._generate(device="cuda")
         constraints: List[torch.export.Constraint] = [
@@ -89,9 +89,7 @@ class TestMemoryPlanning(TestCase):
         ]
         with config.patch("aot_inductor.abi_compatible", True):
             result, code = run_and_get_cpp_code(
-                lambda: AOTInductorModelRunner.run(
-                    "cuda", f, args, constraints=constraints
-                )
+                lambda: AOTIRunnerUtil.run("cuda", f, args, constraints=constraints)
             )

             FileCheck().check(
@@ -56,7 +56,7 @@ from torch._prims_common import DeviceLikeType

 # This module is defined in torch/csrc/Module.cpp

-from . import _functorch, _lazy, _lazy_ts_backend, _nn, _onnx, _VariableFunctions, _cpu
+from . import _functorch, _lazy, _lazy_ts_backend, _nn, _onnx, _VariableFunctions, _cpu, _aoti

 K = TypeVar("K")
 T = TypeVar("T")
torch/_C/_aoti.pyi (new file, 3 lines)
@@ -0,0 +1,3 @@
+# Defined in torch/csrc/inductor/aoti_runner/pybind.cpp
+class AOTIModelContainerRunnerCpu: ...
+class AOTIModelContainerRunnerCuda: ...
@@ -60,6 +60,7 @@
 #include <torch/csrc/cpu/Module.h>
 #include <torch/csrc/dynamo/init.h>
 #include <torch/csrc/functorch/init.h>
+#include <torch/csrc/inductor/aoti_runner/pybind.h>
 #include <torch/csrc/jit/python/init.h>
 #include <torch/csrc/jit/python/python_ir.h>
 #include <torch/csrc/jit/python/python_tracer.h>
@@ -1513,6 +1514,7 @@ PyObject* initModule() {
   torch::profiler::initPythonBindings(module);
   torch::python::init_bindings(module);
   torch::lazy::initLazyBindings(module);
+  torch::inductor::initAOTIRunnerBindings(module);
 #ifdef USE_ITT
   torch::profiler::initIttBindings(module);
 #endif
@@ -13,14 +13,19 @@ class TORCH_API AOTIModelContainerRunnerCuda : public AOTIModelContainerRunner {
       const std::string& cubin_dir = "")
       : AOTIModelContainerRunner(model_so_path, num_models, false, cubin_dir) {}

-  std::vector<at::Tensor> run(
-      std::vector<at::Tensor>& inputs,
-      cudaStream_t cuda_stream_handle = nullptr) {
-    if (cuda_stream_handle == nullptr) {
-      cuda_stream_handle = c10::cuda::getCurrentCUDAStream().stream();
-    }
+  std::vector<at::Tensor> run(std::vector<at::Tensor>& inputs) {
+    at::cuda::CUDAStream cuda_stream = c10::cuda::getCurrentCUDAStream();
     return AOTIModelContainerRunner::run(
-        inputs, reinterpret_cast<AOTInductorStreamHandle>(cuda_stream_handle));
+        inputs,
+        reinterpret_cast<AOTInductorStreamHandle>(cuda_stream.stream()));
   }

+  std::vector<at::Tensor> run_with_cuda_stream(
+      std::vector<at::Tensor>& inputs,
+      at::cuda::CUDAStream cuda_stream) {
+    return AOTIModelContainerRunner::run(
+        inputs,
+        reinterpret_cast<AOTInductorStreamHandle>(cuda_stream.stream()));
+  }
 };
torch/csrc/inductor/aoti_runner/pybind.cpp (new file, 26 lines)
@@ -0,0 +1,26 @@
+#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
+#ifdef USE_CUDA
+#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
+#endif
+
+#include <torch/csrc/utils/pybind.h>
+
+namespace torch::inductor {
+
+void initAOTIRunnerBindings(PyObject* module) {
+  auto rootModule = py::handle(module).cast<py::module>();
+  auto m = rootModule.def_submodule("_aoti");
+
+  py::class_<AOTIModelContainerRunnerCpu>(m, "AOTIModelContainerRunnerCpu")
+      .def(py::init<const std::string&, int>())
+      .def("run", &AOTIModelContainerRunnerCpu::run)
+      .def("get_call_spec", &AOTIModelContainerRunnerCpu::get_call_spec);
+
+#ifdef USE_CUDA
+  py::class_<AOTIModelContainerRunnerCuda>(m, "AOTIModelContainerRunnerCuda")
+      .def(py::init<const std::string&, int>())
+      .def("run", &AOTIModelContainerRunnerCuda::run)
+      .def("get_call_spec", &AOTIModelContainerRunnerCuda::get_call_spec);
+#endif
+}
+} // namespace torch::inductor
torch/csrc/inductor/aoti_runner/pybind.h (new file, 7 lines)
@@ -0,0 +1,7 @@
+#include <torch/csrc/python_headers.h>
+
+namespace torch::inductor {
+
+void initAOTIRunnerBindings(PyObject* module);
+
+} // namespace torch::inductor