Fix for AOTI + CUDAGraphs when calling from Python (#148601)

**Background**: I've been comparing performance of torch.compile vs. torch.export + AOTI (specifically, loaded from Python) on the Flux model and found a ~1.4% performance decrease with the latter. The trace shows that CUDAGraphs are not utilized for torch.export + AOTI, leading to higher overhead.

When I try to manually CUDAGraph the loaded, previously exported + AOTIed model (thanks to @eellison for the wrapping logic here), I get:
```
Error: operation not permitted when stream is capturing
```

@desertfire confirms that this is due to multi-threading logic on the AOTI runtime side (in `AOTIModelContainer` / `AOTIModel`) conflicting with the use of CUDAGraphs.
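
For concreteness, here is a minimal sketch of the manual capture that hits this error with the default (multi-threaded) runtime path; the package path and input shape are hypothetical, and the sketch assumes a package previously produced by `torch._inductor.aoti_compile_and_package`:

```python
import torch
import torch._inductor

# Hypothetical package produced earlier by torch._inductor.aoti_compile_and_package(...)
compiled = torch._inductor.aoti_load_package("model.pt2")

x = torch.randn(3, 10, device="cuda")
compiled(x)  # warmup run outside of capture

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    # The runtime's thread-synchronization logic issues operations that are not
    # allowed while the stream is capturing, so this raises
    # "Error: operation not permitted when stream is capturing".
    compiled(x)
```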

**Fix**: This PR provides an alternate, single-threaded method for running loaded models with the AOTI runtime (see the usage sketch after the list below). Details:
* Python side introduces a new flag to enable this behavior (needs a better name): `torch._inductor.package.load_package(..., run_single_threaded=False)`
    * This flag is passed down to the C++ side's `AOTIModelPackageLoader`, which passes it to the `CreateAOTIModelRunnerFunc` during `AOTIModelContainerRunner` construction.
* C++ side introduces single-threaded alternatives to model running and model container running:
    * `AOTIModelContainer.run_single_threaded()` / `AOTIModel.run_single_threaded()`. The interfaces match those of `run()`, but the synchronization logic has been removed.
    * Introduces `AOTInductorModelContainerRunSingleThreaded` to AOTI's `interface.h`; this is invoked by the `AOTIModelContainerRunner` utility class when `run_single_threaded=true`.
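
A rough usage sketch of the new flag as it flows through the Python entry points touched by this PR (the package path is hypothetical):

```python
import torch
import torch._inductor
from torch._inductor.package import load_package

# High-level entry point (also exercised by the new test below):
compiled = torch._inductor.aoti_load_package("model.pt2", run_single_threaded=True)

# Equivalent lower-level entry point:
compiled = load_package("model.pt2", model_name="model", run_single_threaded=True)

# The loaded model can now be captured/replayed manually with torch.cuda.CUDAGraph
# without tripping over the runtime's thread-synchronization logic.
```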

I've verified on both a small repro and my real-world use case that I can manually CUDAGraph a loaded model that was previously exported + AOTIed.

**Future work:**
* Flip the default value to `run_single_threaded=True`, since Python-side inference doesn't take advantage of the AOTI runtime thread pool
    * There are some BC concerns here: models need to be re-serialized so the .so contains the new `AOTInductorModelContainerRunSingleThreaded` interface function. We can flip the default value and warn (instead of crashing) if the `AOTInductorModelContainerRunSingleThreaded` symbol does not exist (see the illustrative fallback sketch after this list).
* Compose with CUDAGraph trees instead of manual CUDA graph wrapping
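
To make the BC fallback above concrete, here is an illustrative, hedged sketch written in Python (using ctypes purely for demonstration); the actual symbol lookup happens in the C++ `AOTIModelContainerRunner` constructor, and `pick_run_symbol` is a hypothetical helper, not part of this PR:

```python
import ctypes
import warnings


def pick_run_symbol(model_so_path: str, run_single_threaded: bool) -> str:
    """Pick the container-run entry point, warning (instead of crashing) when an
    older .so predates AOTInductorModelContainerRunSingleThreaded."""
    name = (
        "AOTInductorModelContainerRunSingleThreaded"
        if run_single_threaded
        else "AOTInductorModelContainerRun"
    )
    lib = ctypes.CDLL(model_so_path)
    # ctypes raises AttributeError for missing symbols, so hasattr works as a probe.
    if run_single_threaded and not hasattr(lib, name):
        warnings.warn(
            f"{name} not found in {model_so_path}; falling back to "
            "AOTInductorModelContainerRun. Re-serialize the package to pick up "
            "the single-threaded entry point."
        )
        name = "AOTInductorModelContainerRun"
    return name
```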

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148601
Approved by: https://github.com/desertfire
Commit: 85467ed063 (parent: 9f170d9d13)
Author: Joel Schlosser
Committed by: PyTorch MergeBot
Date: 2025-03-07 15:58:33 -05:00
19 changed files with 244 additions and 38 deletions

@@ -4566,6 +4566,78 @@ class AOTInductorTestsTemplate:
}
self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes)
def test_with_cudagraphs(self):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
# define CUDAGraph handling wrapper (only works with kwargs for simplicity)
def cudagraph(f):
_graphs = {}
def f_(**kwargs):
key = hash(
tuple(
tuple(kwargs[a].shape)
for a in sorted(kwargs.keys())
if isinstance(kwargs[a], torch.Tensor)
)
)
if key in _graphs:
wrapped, *_ = _graphs[key]
return wrapped(**kwargs)
g = torch.cuda.CUDAGraph()
in_tensors = {
k: v.clone() if isinstance(v, torch.Tensor) else v
for k, v in kwargs.items()
}
f(**in_tensors) # stream warmup
with torch.cuda.graph(g):
out_tensors = f(**in_tensors)
def wrapped(**kwargs):
for key in kwargs:
in_tensors[key].copy_(kwargs[key])
g.replay()
if isinstance(out_tensors, torch.Tensor):
return out_tensors.clone()
elif isinstance(out_tensors, (list, tuple)):
return type(out_tensors)(o.clone() for o in out_tensors)
raise ValueError("unsupported output type encountered")
_graphs[key] = (wrapped, g, in_tensors, out_tensors)
return wrapped(**kwargs)
return f_
# define a simple model
model = torch.nn.Linear(10, 20).to(device=self.device)
# export + AOTI
model_kwargs = {
"input": torch.randn(3, 10, device=self.device),
}
ep = torch.export.export(model, args=(), kwargs=model_kwargs)
optimized = torch._inductor.aoti_load_package(
torch._inductor.aoti_compile_and_package(
ep,
inductor_configs={"max_autotune": True},
),
# NB: this flag avoids a CUDAGraph + AOTI runtime multi-threading conflict
# "Error: operation not permitted when stream is capturing"
run_single_threaded=True,
)
# enable CUDAGraphs
optimized = cudagraph(optimized)
# warmup -> run with CUDAGraphs
for _ in range(3):
optimized(**model_kwargs)
# compare against eager
self.assertEqual(optimized(**model_kwargs), model(**model_kwargs))
class AOTInductorLoggingTest(LoggingTestCase):
@make_logging_test(dynamic=logging.DEBUG)

@@ -232,7 +232,7 @@ def _aoti_compile_and_package_inner(
return package_path
def aoti_load_package(path: FileLike) -> Any: # type: ignore[type-arg]
def aoti_load_package(path: FileLike, run_single_threaded: bool = False) -> Any: # type: ignore[type-arg]
"""
Loads the model from the PT2 package.
@@ -248,10 +248,13 @@ def aoti_load_package(path: FileLike) -> Any: # type: ignore[type-arg]
Args:
path: Path to the .pt2 package
run_single_threaded (bool): Whether the model should be run without
thread synchronization logic. This is useful to avoid conflicts with
CUDAGraphs.
"""
from torch._inductor.package import load_package
return load_package(path)
return load_package(path, run_single_threaded=run_single_threaded)
def aot_compile(

@@ -117,6 +117,33 @@ AOTIRuntimeError AOTInductorModelContainerRun(
})
}
AOTIRuntimeError AOTInductorModelContainerRunSingleThreaded(
AOTInductorModelContainerHandle container_handle,
AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles
// are stolen; the array itself is borrowed
size_t num_inputs,
AtenTensorHandle*
output_handles, // array for writing output AtenTensorHandle; handles
// will be stolen by the caller; the array itself is
// borrowed
size_t num_outputs,
AOTInductorStreamHandle stream_handle,
AOTIProxyExecutorHandle proxy_executor_handle) {
auto* container =
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
container_handle);
AOTI_VECTOR_SIZE_CHECK(num_inputs, container->num_inputs(), "inputs");
AOTI_VECTOR_SIZE_CHECK(num_outputs, container->num_outputs(), "outputs");
auto stream =
reinterpret_cast<torch::aot_inductor::DeviceStreamType>(stream_handle);
CONVERT_EXCEPTION_TO_ERROR_CODE({
AOTINoGradGuard guard;
container->run_single_threaded(
input_handles, output_handles, stream, proxy_executor_handle);
})
}
AOTIRuntimeError AOTInductorModelContainerGetNumConstants(
AOTInductorModelContainerHandle container_handle,
size_t* num_constants) {

@@ -274,7 +274,9 @@ class AOTICompiledModel:
return AOTICompiledModel(self.loader) # type: ignore[attr-defined]
def load_package(path: FileLike, model_name: str = "model") -> AOTICompiledModel: # type: ignore[type-arg]
def load_package(
path: FileLike, model_name: str = "model", run_single_threaded: bool = False
) -> AOTICompiledModel: # type: ignore[type-arg]
assert (
isinstance(path, (io.IOBase, IO)) and path.readable() and path.seekable()
) or (isinstance(path, (str, os.PathLike)) and os.fspath(path).endswith(".pt2")), (
@@ -288,9 +290,13 @@ def load_package(path: FileLike, model_name: str = "model") -> AOTICompiledModel
f.write(path.read())
path.seek(0)
log.debug("Writing buffer to tmp file located at %s.", f.name)
loader = torch._C._aoti.AOTIModelPackageLoader(f.name, model_name) # type: ignore[call-arg]
loader = torch._C._aoti.AOTIModelPackageLoader(
f.name, model_name, run_single_threaded
) # type: ignore[call-arg]
return AOTICompiledModel(loader)
path = os.fspath(path) # AOTIModelPackageLoader expects (str, str)
loader = torch._C._aoti.AOTIModelPackageLoader(path, model_name) # type: ignore[call-arg]
loader = torch._C._aoti.AOTIModelPackageLoader(
path, model_name, run_single_threaded
) # type: ignore[call-arg]
return AOTICompiledModel(loader)

@@ -437,7 +437,8 @@ std::shared_ptr<AOTIModelContainerRunner> AOTIPythonKernelHolder::
return std::make_shared<AOTIModelContainerRunnerCpu>(so_path);
} else {
auto aoti_model_runer_fn = registered_aoti_runner[device_name];
return aoti_model_runer_fn(so_path, 1, device_name, "");
return aoti_model_runer_fn(
so_path, 1, device_name, "", /*run_single_threaded=*/false);
}
}

@@ -339,12 +339,15 @@ void AOTIModelPackageLoader::load_metadata(const std::string& cpp_filename) {
}
AOTIModelPackageLoader::AOTIModelPackageLoader(
const std::string& model_package_path)
: AOTIModelPackageLoader(model_package_path, "model") {}
const std::string& model_package_path,
const bool run_single_threaded = false)
: AOTIModelPackageLoader(model_package_path, "model", run_single_threaded) {
}
AOTIModelPackageLoader::AOTIModelPackageLoader(
const std::string& model_package_path,
const std::string& model_name = "model") {
const std::string& model_name = "model",
const bool run_single_threaded = false) {
// Extract all files within the zipfile to a temporary directory
mz_zip_archive zip_archive;
memset(&zip_archive, 0, sizeof(zip_archive));
@@ -457,7 +460,8 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
}
std::string cubin_dir = temp_dir_ + k_separator + model_directory;
runner_ = registered_aoti_runner[device](so_path, 1, device, cubin_dir);
runner_ = registered_aoti_runner[device](
so_path, 1, device, cubin_dir, run_single_threaded);
}
AOTIModelPackageLoader::~AOTIModelPackageLoader() {

@@ -7,10 +7,13 @@
namespace torch::inductor {
class TORCH_API AOTIModelPackageLoader {
public:
AOTIModelPackageLoader(const std::string& model_package_path);
AOTIModelPackageLoader(
const std::string& model_package_path,
const std::string& model_name);
const bool run_single_threaded);
AOTIModelPackageLoader(
const std::string& model_package_path,
const std::string& model_name,
const bool run_single_threaded);
~AOTIModelPackageLoader();
AOTIModelContainerRunner* get_runner();

@@ -13,13 +13,19 @@ namespace torch::inductor {
class AOTIModelPackageLoaderPybind : public AOTIModelPackageLoader {
public:
AOTIModelPackageLoaderPybind(const std::string& model_package_path)
: AOTIModelPackageLoader(model_package_path) {}
AOTIModelPackageLoaderPybind(
const std::string& model_package_path,
const bool run_single_threaded)
: AOTIModelPackageLoader(model_package_path, run_single_threaded) {}
AOTIModelPackageLoaderPybind(
const std::string& model_package_path,
const std::string& model_name)
: AOTIModelPackageLoader(model_package_path, model_name) {}
const std::string& model_name,
const bool run_single_threaded)
: AOTIModelPackageLoader(
model_package_path,
model_name,
run_single_threaded) {}
py::list boxed_run(py::list& inputs, void* stream_handle = nullptr) {
std::vector<at::Tensor> input_tensors;
@@ -47,8 +53,8 @@ void initAOTIPackageBindings(PyObject* module) {
auto m = rootModule.def_submodule("_aoti");
py::class_<AOTIModelPackageLoaderPybind>(m, "AOTIModelPackageLoader")
.def(py::init<const std::string&, const std::string&>())
.def(py::init<const std::string&>())
.def(py::init<const std::string&, const std::string&, const bool>())
.def(py::init<const std::string&, const bool>())
.def("get_metadata", &AOTIModelPackageLoaderPybind::get_metadata)
.def(
"run",

@@ -29,7 +29,8 @@ AOTIModelContainerRunner::AOTIModelContainerRunner(
const std::string& model_so_path,
size_t num_models,
const std::string& device_str,
const std::string& cubin_dir) {
const std::string& cubin_dir,
const bool run_single_threaded) {
model_so_ = std::make_unique<at::DynamicLibrary>(model_so_path.c_str());
TORCH_CHECK(model_so_, "Failed to load model: ", model_so_path);
create_func_ = reinterpret_cast<decltype(create_func_)>(
@@ -38,8 +39,9 @@ AOTIModelContainerRunner::AOTIModelContainerRunner(
model_so_->sym("AOTInductorModelContainerDelete"));
get_num_outputs_func_ = reinterpret_cast<decltype(get_num_outputs_func_)>(
model_so_->sym("AOTInductorModelContainerGetNumOutputs"));
run_func_ = reinterpret_cast<decltype(run_func_)>(
model_so_->sym("AOTInductorModelContainerRun"));
run_func_ = reinterpret_cast<decltype(run_func_)>(model_so_->sym(
run_single_threaded ? "AOTInductorModelContainerRunSingleThreaded"
: "AOTInductorModelContainerRun"));
get_num_constants_func_ = reinterpret_cast<decltype(get_num_constants_func_)>(
model_so_->sym("AOTInductorModelContainerGetNumConstants"));
get_constant_name_func_ = reinterpret_cast<decltype(get_constant_name_func_)>(

@@ -58,7 +58,8 @@ class TORCH_API AOTIModelContainerRunner {
const std::string& model_so_path,
size_t num_models,
const std::string& device_str,
const std::string& cubin_dir);
const std::string& cubin_dir,
const bool run_single_threaded);
virtual std::vector<at::Tensor> run_impl(
std::vector<AtenTensorHandle>& input_handles,
@@ -100,7 +101,8 @@ using CreateAOTIModelRunnerFunc = std::unique_ptr<AOTIModelContainerRunner> (*)(
const std::string& model_so_path,
size_t num_models,
const std::string& device_str,
const std::string& bin_dir);
const std::string& bin_dir,
const bool run_single_threaded);
// Return a global map "device name" -> "aoti model runner create function" for
// all registered in AOTI external backends

@@ -7,8 +7,14 @@ namespace torch::inductor {
// We provide NO BC guarantee for these APIs
AOTIModelContainerRunnerCpu::AOTIModelContainerRunnerCpu(
const std::string& model_so_path,
size_t num_models)
: AOTIModelContainerRunner(model_so_path, num_models, "cpu", "") {}
size_t num_models,
bool run_single_threaded)
: AOTIModelContainerRunner(
model_so_path,
num_models,
"cpu",
"",
run_single_threaded) {}
AOTIModelContainerRunnerCpu::~AOTIModelContainerRunnerCpu() = default;
@@ -17,12 +23,13 @@ std::unique_ptr<AOTIModelContainerRunner> create_aoti_runner_cpu(
const std::string& model_so_path,
size_t num_models,
const std::string& device_str,
const std::string& cubin_dir) {
const std::string& cubin_dir,
const bool run_single_threaded) {
if (device_str != "cpu") {
throw std::runtime_error("Incorrect device passed to aoti_runner_cpu");
}
return std::make_unique<AOTIModelContainerRunnerCpu>(
model_so_path, num_models);
model_so_path, num_models, run_single_threaded);
}
} // namespace

@@ -8,7 +8,8 @@ class TORCH_API AOTIModelContainerRunnerCpu : public AOTIModelContainerRunner {
public:
AOTIModelContainerRunnerCpu(
const std::string& model_so_path,
size_t num_models = 1);
size_t num_models = 1,
const bool run_single_threaded = false);
~AOTIModelContainerRunnerCpu() override;
};

@@ -7,12 +7,14 @@ AOTIModelContainerRunnerCuda::AOTIModelContainerRunnerCuda(
const std::string& model_so_path,
size_t num_models,
const std::string& device_str,
const std::string& cubin_dir)
const std::string& cubin_dir,
const bool run_single_threaded)
: AOTIModelContainerRunner(
model_so_path,
num_models,
device_str,
cubin_dir) {}
cubin_dir,
run_single_threaded) {}
AOTIModelContainerRunnerCuda::~AOTIModelContainerRunnerCuda() = default;
@@ -37,9 +39,10 @@ std::unique_ptr<AOTIModelContainerRunner> create_aoti_runner_cuda(
const std::string& model_so_path,
size_t num_models,
const std::string& device_str,
const std::string& cubin_dir) {
const std::string& cubin_dir,
const bool run_single_threaded) {
return std::make_unique<AOTIModelContainerRunnerCuda>(
model_so_path, num_models, device_str, cubin_dir);
model_so_path, num_models, device_str, cubin_dir, run_single_threaded);
}
} // namespace

@@ -17,7 +17,8 @@ class TORCH_CUDA_CPP_API AOTIModelContainerRunnerCuda
const std::string& model_so_path,
size_t num_models = 1,
const std::string& device_str = "cuda",
const std::string& cubin_dir = "");
const std::string& cubin_dir = "",
const bool run_single_threaded = false);
~AOTIModelContainerRunnerCuda() override;

@@ -7,12 +7,14 @@ AOTIModelContainerRunnerXpu::AOTIModelContainerRunnerXpu(
const std::string& model_so_path,
size_t num_models,
const std::string& device_str,
const std::string& kernel_bin_dir)
const std::string& kernel_bin_dir,
const bool run_single_threaded)
: AOTIModelContainerRunner(
model_so_path,
num_models,
device_str,
kernel_bin_dir) {}
kernel_bin_dir,
run_single_threaded) {}
AOTIModelContainerRunnerXpu::~AOTIModelContainerRunnerXpu() = default;
@@ -37,9 +39,14 @@ std::unique_ptr<AOTIModelContainerRunner> create_aoti_runner_xpu(
const std::string& model_so_path,
size_t num_models,
const std::string& device_str,
const std::string& kernel_bin_dir) {
const std::string& kernel_bin_dir,
const bool run_single_threaded) {
return std::make_unique<AOTIModelContainerRunnerXpu>(
model_so_path, num_models, device_str, kernel_bin_dir);
model_so_path,
num_models,
device_str,
kernel_bin_dir,
run_single_threaded);
}
} // namespace

@@ -19,7 +19,8 @@ class C10_EXPORT AOTIModelContainerRunnerXpu : public AOTIModelContainerRunner {
const std::string& model_so_path,
size_t num_models = 1,
const std::string& device_str = "xpu",
const std::string& kernel_bin_dir = "");
const std::string& kernel_bin_dir = "",
const bool run_single_threaded = false);
~AOTIModelContainerRunnerXpu() override;

@@ -58,6 +58,20 @@ AOTIRuntimeError AOTInductorModelContainerRun(
AOTInductorStreamHandle stream_handle,
AOTIProxyExecutorHandle proxy_executor_handle);
// Single-threaded variant of previous.
AOTIRuntimeError AOTInductorModelContainerRunSingleThreaded(
AOTInductorModelContainerHandle container_handle,
AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles
// are stolen; the array itself is borrowed
size_t num_inputs,
AtenTensorHandle*
output_handles, // array for writing output AtenTensorHandle; handles
// will be stolen by the caller; the array itself is
// borrowed
size_t num_outputs,
AOTInductorStreamHandle stream_handle,
AOTIProxyExecutorHandle proxy_executor_handle);
// Retrieves the number of constants for the model.
AOTIRuntimeError AOTInductorModelContainerGetNumConstants(
AOTInductorModelContainerHandle container_handle,

@@ -231,6 +231,24 @@ class AOTInductorModelBase {
#endif // USE_CUDA
}
// Non-thread-aware variant of run(). Obviously unsafe to use in a threaded
// environment :)
void run_single_threaded(
AtenTensorHandle*
input_handles, // array of input AtenTensorHandle; handles
// are stolen; the array itself is borrowed
AtenTensorHandle*
output_handles, // array for writing output AtenTensorHandle; handles
// will be stolen by the caller; the array itself is
// borrowed
DeviceStreamType stream,
AOTIProxyExecutorHandle proxy_executor) {
// don't bother with any of the run_finished stuff; this is unsafe to call
// in a threaded context
auto* model = static_cast<Model*>(this);
model->run_impl(input_handles, output_handles, stream, proxy_executor);
}
std::unordered_map<std::string, AtenTensorHandle> run_const_fold(
DeviceStreamType stream,
AOTIProxyExecutorHandle proxy_executor,

@@ -112,6 +112,34 @@ class AOTInductorModelContainer {
pending_models_available_.notify_one();
}
// Non-thread-aware variant of run(). Obviously unsafe to use in a threaded
// environment :)
void run_single_threaded(
AtenTensorHandle*
input_handles, // array of input AtenTensorHandle; handles
// are stolen; the array itself is borrowed
AtenTensorHandle*
output_handles, // array for writing output AtenTensorHandle; handles
// will be stolen by the caller; the array itself is
// borrowed
DeviceStreamType stream,
AOTIProxyExecutorHandle proxy_executor) {
auto* model = available_models_[0];
if (!constant_folded_) {
auto folded_const_map = model->run_const_fold(
stream, proxy_executor, /* initialization = */ true);
update_constant_buffer(
std::move(folded_const_map),
/* use_inactive = */ false,
/* validate_full_update = */ false);
constant_folded_ = true;
}
model->run_single_threaded(
input_handles, output_handles, stream, proxy_executor);
}
size_t num_constants() const {
if (this->num_models() == 0) {
throw std::runtime_error("No available models in container!");