Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Fix for AOTI + CUDAGraphs when calling from Python (#148601)
**Background**: I've been comparing the performance of torch.compile vs. torch.export + AOTI (specifically, loaded from Python) on the Flux model and found a ~1.4% performance decrease with the latter. The trace shows that CUDAGraphs are not utilized for torch.export + AOTI, leading to higher overhead. When trying to manually CUDAGraph the loaded, previously exported + AOTIed model (thanks to @eellison for the logic here), I get:
```
Error: operation not permitted when stream is capturing
```
@desertfire confirms that this is due to multi-threading logic on the AOTI runtime side (in `AOTIModelContainer` / `AOTIModel`) conflicting with the use of CUDAGraphs.

**Fix**: This PR takes the approach of providing an alternate, single-threaded method for running loaded models with the AOTI runtime. Details:
* The Python side introduces a new flag to enable this behavior (needs a better name): `torch._inductor.package.load_package(..., run_single_threaded=False)`
* This flag is passed down to the C++ side's `AOTIModelPackageLoader`, which passes it to the `CreateAOTIModelRunnerFunc` during `AOTIModelContainerRunner` construction.
* The C++ side introduces single-threaded alternatives to model running and model container running:
  * `AOTIModelContainer.run_single_threaded()` / `AOTIModel.run_single_threaded()`. The interfaces match those of `run()`, but the synchronization logic has been removed.
  * `AOTInductorModelContainerRunSingleThreaded` is added to AOTI's `interface.h`; this is invoked by the `AOTIModelContainerRunner` utility class when `run_single_threaded=true`.

I've verified on both a small repro and my real-world use case that I can manually CUDAGraph a loaded model that was previously exported + AOTIed.

**Future work:**
* Flip the default value to `run_single_threaded=True`, as Python-side inference doesn't take advantage of the AOTI runtime thread pool.
  * There are some BC concerns here: models need to be re-serialized so the .so contains the new `AOTInductorModelContainerRunSingleThreaded` interface func. We can flip the default value and warn (instead of crashing) if the `AOTInductorModelContainerRunSingleThreaded` symbol does not exist.
* Compose with cudagraph trees as opposed to manual CUDA graph wrapping.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148601
Approved by: https://github.com/desertfire
Committed by: PyTorch MergeBot
Parent: 9f170d9d13
Commit: 85467ed063
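For quick reference, here is a minimal Python sketch of the flow this PR enables, condensed from the test added in the diff below; the toy `Linear` model, shapes, and variable names are illustrative only, not part of the PR's API surface:

```python
import torch

# Toy model and inputs (illustrative; any exportable model works the same way).
model = torch.nn.Linear(10, 20).cuda()
kwargs = {"input": torch.randn(3, 10, device="cuda")}

# Export + AOTI compile to a .pt2 package.
ep = torch.export.export(model, args=(), kwargs=kwargs)
pt2_path = torch._inductor.aoti_compile_and_package(ep)

# run_single_threaded=True selects the new synchronization-free run path, so the
# loaded model can be captured into a CUDA graph without hitting
# "operation not permitted when stream is capturing".
optimized = torch._inductor.aoti_load_package(pt2_path, run_single_threaded=True)

# Manual capture/replay (composing with cudagraph trees is listed as future work).
static_inputs = {k: v.clone() for k, v in kwargs.items()}
optimized(**static_inputs)  # warmup
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out = optimized(**static_inputs)

static_inputs["input"].copy_(kwargs["input"])  # refresh inputs in place
g.replay()                                     # static_out now holds the new result
```

This mirrors what the `test_with_cudagraphs` test below does with its `cudagraph` wrapper helper.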
@@ -4566,6 +4566,78 @@ class AOTInductorTestsTemplate:
         }
         self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes)
 
+    def test_with_cudagraphs(self):
+        if self.device != "cuda":
+            raise unittest.SkipTest("requires CUDA")
+
+        # define CUDAGraph handling wrapper (only works with kwargs for simplicity)
+        def cudagraph(f):
+            _graphs = {}
+
+            def f_(**kwargs):
+                key = hash(
+                    tuple(
+                        tuple(kwargs[a].shape)
+                        for a in sorted(kwargs.keys())
+                        if isinstance(kwargs[a], torch.Tensor)
+                    )
+                )
+                if key in _graphs:
+                    wrapped, *_ = _graphs[key]
+                    return wrapped(**kwargs)
+                g = torch.cuda.CUDAGraph()
+                in_tensors = {
+                    k: v.clone() if isinstance(v, torch.Tensor) else v
+                    for k, v in kwargs.items()
+                }
+                f(**in_tensors)  # stream warmup
+                with torch.cuda.graph(g):
+                    out_tensors = f(**in_tensors)
+
+                def wrapped(**kwargs):
+                    for key in kwargs:
+                        in_tensors[key].copy_(kwargs[key])
+                    g.replay()
+                    if isinstance(out_tensors, torch.Tensor):
+                        return out_tensors.clone()
+                    elif isinstance(out_tensors, (list, tuple)):
+                        return type(out_tensors)(o.clone() for o in out_tensors)
+                    raise ValueError("unsupported output type encountered")
+
+                _graphs[key] = (wrapped, g, in_tensors, out_tensors)
+                return wrapped(**kwargs)
+
+            return f_
+
+        # define a simple model
+        model = torch.nn.Linear(10, 20).to(device=self.device)
+
+        # export + AOTI
+        model_kwargs = {
+            "input": torch.randn(3, 10, device=self.device),
+        }
+        ep = torch.export.export(model, args=(), kwargs=model_kwargs)
+
+        optimized = torch._inductor.aoti_load_package(
+            torch._inductor.aoti_compile_and_package(
+                ep,
+                inductor_configs={"max_autotune": True},
+            ),
+            # NB: this flag avoids a CUDAGraph + AOTI runtime multi-threading conflict
+            # "Error: operation not permitted when stream is capturing"
+            run_single_threaded=True,
+        )
+
+        # enable CUDAGraphs
+        optimized = cudagraph(optimized)
+
+        # warmup -> run with CUDAGraphs
+        for _ in range(3):
+            optimized(**model_kwargs)
+
+        # compare against eager
+        self.assertEqual(optimized(**model_kwargs), model(**model_kwargs))
+
 
 class AOTInductorLoggingTest(LoggingTestCase):
     @make_logging_test(dynamic=logging.DEBUG)
@@ -232,7 +232,7 @@ def _aoti_compile_and_package_inner(
     return package_path
 
 
-def aoti_load_package(path: FileLike) -> Any:  # type: ignore[type-arg]
+def aoti_load_package(path: FileLike, run_single_threaded: bool = False) -> Any:  # type: ignore[type-arg]
     """
     Loads the model from the PT2 package.
@@ -248,10 +248,13 @@ def aoti_load_package(path: FileLike) -> Any:  # type: ignore[type-arg]
 
     Args:
         path: Path to the .pt2 package
+        run_single_threaded (bool): Whether the model should be run without
+            thread synchronization logic. This is useful to avoid conflicts with
+            CUDAGraphs.
     """
     from torch._inductor.package import load_package
 
-    return load_package(path)
+    return load_package(path, run_single_threaded=run_single_threaded)
 
 
 def aot_compile(
@@ -117,6 +117,33 @@ AOTIRuntimeError AOTInductorModelContainerRun(
   })
 }
 
+AOTIRuntimeError AOTInductorModelContainerRunSingleThreaded(
+    AOTInductorModelContainerHandle container_handle,
+    AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles
+                                     // are stolen; the array itself is borrowed
+    size_t num_inputs,
+    AtenTensorHandle*
+        output_handles, // array for writing output AtenTensorHandle; handles
+                        // will be stolen by the caller; the array itself is
+                        // borrowed
+    size_t num_outputs,
+    AOTInductorStreamHandle stream_handle,
+    AOTIProxyExecutorHandle proxy_executor_handle) {
+  auto* container =
+      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
+          container_handle);
+  AOTI_VECTOR_SIZE_CHECK(num_inputs, container->num_inputs(), "inputs");
+  AOTI_VECTOR_SIZE_CHECK(num_outputs, container->num_outputs(), "outputs");
+
+  auto stream =
+      reinterpret_cast<torch::aot_inductor::DeviceStreamType>(stream_handle);
+  CONVERT_EXCEPTION_TO_ERROR_CODE({
+    AOTINoGradGuard guard;
+    container->run_single_threaded(
+        input_handles, output_handles, stream, proxy_executor_handle);
+  })
+}
+
 AOTIRuntimeError AOTInductorModelContainerGetNumConstants(
     AOTInductorModelContainerHandle container_handle,
     size_t* num_constants) {
@@ -274,7 +274,9 @@ class AOTICompiledModel:
         return AOTICompiledModel(self.loader)  # type: ignore[attr-defined]
 
 
-def load_package(path: FileLike, model_name: str = "model") -> AOTICompiledModel:  # type: ignore[type-arg]
+def load_package(
+    path: FileLike, model_name: str = "model", run_single_threaded: bool = False
+) -> AOTICompiledModel:  # type: ignore[type-arg]
     assert (
         isinstance(path, (io.IOBase, IO)) and path.readable() and path.seekable()
     ) or (isinstance(path, (str, os.PathLike)) and os.fspath(path).endswith(".pt2")), (
@@ -288,9 +290,13 @@ def load_package(path: FileLike, model_name: str = "model") -> AOTICompiledModel
             f.write(path.read())
             path.seek(0)
             log.debug("Writing buffer to tmp file located at %s.", f.name)
-            loader = torch._C._aoti.AOTIModelPackageLoader(f.name, model_name)  # type: ignore[call-arg]
+            loader = torch._C._aoti.AOTIModelPackageLoader(
+                f.name, model_name, run_single_threaded
+            )  # type: ignore[call-arg]
             return AOTICompiledModel(loader)
 
     path = os.fspath(path)  # AOTIModelPackageLoader expects (str, str)
-    loader = torch._C._aoti.AOTIModelPackageLoader(path, model_name)  # type: ignore[call-arg]
+    loader = torch._C._aoti.AOTIModelPackageLoader(
+        path, model_name, run_single_threaded
+    )  # type: ignore[call-arg]
     return AOTICompiledModel(loader)
@@ -437,7 +437,8 @@ std::shared_ptr<AOTIModelContainerRunner> AOTIPythonKernelHolder::
     return std::make_shared<AOTIModelContainerRunnerCpu>(so_path);
   } else {
     auto aoti_model_runer_fn = registered_aoti_runner[device_name];
-    return aoti_model_runer_fn(so_path, 1, device_name, "");
+    return aoti_model_runer_fn(
+        so_path, 1, device_name, "", /*run_single_threaded=*/false);
   }
 }
 
@@ -339,12 +339,15 @@ void AOTIModelPackageLoader::load_metadata(const std::string& cpp_filename) {
 }
 
 AOTIModelPackageLoader::AOTIModelPackageLoader(
-    const std::string& model_package_path)
-    : AOTIModelPackageLoader(model_package_path, "model") {}
+    const std::string& model_package_path,
+    const bool run_single_threaded = false)
+    : AOTIModelPackageLoader(model_package_path, "model", run_single_threaded) {
+}
 
 AOTIModelPackageLoader::AOTIModelPackageLoader(
     const std::string& model_package_path,
-    const std::string& model_name = "model") {
+    const std::string& model_name = "model",
+    const bool run_single_threaded = false) {
   // Extract all files within the zipfile to a temporary directory
   mz_zip_archive zip_archive;
   memset(&zip_archive, 0, sizeof(zip_archive));
@@ -457,7 +460,8 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
   }
 
   std::string cubin_dir = temp_dir_ + k_separator + model_directory;
-  runner_ = registered_aoti_runner[device](so_path, 1, device, cubin_dir);
+  runner_ = registered_aoti_runner[device](
+      so_path, 1, device, cubin_dir, run_single_threaded);
 }
 
 AOTIModelPackageLoader::~AOTIModelPackageLoader() {
@@ -7,10 +7,13 @@
 namespace torch::inductor {
 class TORCH_API AOTIModelPackageLoader {
  public:
-  AOTIModelPackageLoader(const std::string& model_package_path);
   AOTIModelPackageLoader(
       const std::string& model_package_path,
-      const std::string& model_name);
+      const bool run_single_threaded);
+  AOTIModelPackageLoader(
+      const std::string& model_package_path,
+      const std::string& model_name,
+      const bool run_single_threaded);
   ~AOTIModelPackageLoader();
 
   AOTIModelContainerRunner* get_runner();
@@ -13,13 +13,19 @@ namespace torch::inductor {
 
 class AOTIModelPackageLoaderPybind : public AOTIModelPackageLoader {
  public:
-  AOTIModelPackageLoaderPybind(const std::string& model_package_path)
-      : AOTIModelPackageLoader(model_package_path) {}
+  AOTIModelPackageLoaderPybind(
+      const std::string& model_package_path,
+      const bool run_single_threaded)
+      : AOTIModelPackageLoader(model_package_path, run_single_threaded) {}
 
   AOTIModelPackageLoaderPybind(
       const std::string& model_package_path,
-      const std::string& model_name)
-      : AOTIModelPackageLoader(model_package_path, model_name) {}
+      const std::string& model_name,
+      const bool run_single_threaded)
+      : AOTIModelPackageLoader(
+            model_package_path,
+            model_name,
+            run_single_threaded) {}
 
   py::list boxed_run(py::list& inputs, void* stream_handle = nullptr) {
     std::vector<at::Tensor> input_tensors;
@@ -47,8 +53,8 @@ void initAOTIPackageBindings(PyObject* module) {
   auto m = rootModule.def_submodule("_aoti");
 
   py::class_<AOTIModelPackageLoaderPybind>(m, "AOTIModelPackageLoader")
-      .def(py::init<const std::string&, const std::string&>())
-      .def(py::init<const std::string&>())
+      .def(py::init<const std::string&, const std::string&, const bool>())
+      .def(py::init<const std::string&, const bool>())
      .def("get_metadata", &AOTIModelPackageLoaderPybind::get_metadata)
      .def(
          "run",
@@ -29,7 +29,8 @@ AOTIModelContainerRunner::AOTIModelContainerRunner(
     const std::string& model_so_path,
     size_t num_models,
     const std::string& device_str,
-    const std::string& cubin_dir) {
+    const std::string& cubin_dir,
+    const bool run_single_threaded) {
   model_so_ = std::make_unique<at::DynamicLibrary>(model_so_path.c_str());
   TORCH_CHECK(model_so_, "Failed to load model: ", model_so_path);
   create_func_ = reinterpret_cast<decltype(create_func_)>(
@@ -38,8 +39,9 @@ AOTIModelContainerRunner::AOTIModelContainerRunner(
       model_so_->sym("AOTInductorModelContainerDelete"));
   get_num_outputs_func_ = reinterpret_cast<decltype(get_num_outputs_func_)>(
       model_so_->sym("AOTInductorModelContainerGetNumOutputs"));
-  run_func_ = reinterpret_cast<decltype(run_func_)>(
-      model_so_->sym("AOTInductorModelContainerRun"));
+  run_func_ = reinterpret_cast<decltype(run_func_)>(model_so_->sym(
+      run_single_threaded ? "AOTInductorModelContainerRunSingleThreaded"
+                          : "AOTInductorModelContainerRun"));
   get_num_constants_func_ = reinterpret_cast<decltype(get_num_constants_func_)>(
       model_so_->sym("AOTInductorModelContainerGetNumConstants"));
   get_constant_name_func_ = reinterpret_cast<decltype(get_constant_name_func_)>(
@@ -58,7 +58,8 @@ class TORCH_API AOTIModelContainerRunner {
       const std::string& model_so_path,
       size_t num_models,
       const std::string& device_str,
-      const std::string& cubin_dir);
+      const std::string& cubin_dir,
+      const bool run_single_threaded);
 
   virtual std::vector<at::Tensor> run_impl(
       std::vector<AtenTensorHandle>& input_handles,
@@ -100,7 +101,8 @@ using CreateAOTIModelRunnerFunc = std::unique_ptr<AOTIModelContainerRunner> (*)(
     const std::string& model_so_path,
     size_t num_models,
     const std::string& device_str,
-    const std::string& bin_dir);
+    const std::string& bin_dir,
+    const bool run_single_threaded);
 
 // Return a global map "device name" -> "aoti model runner create function" for
 // all registered in AOTI external backends
@@ -7,8 +7,14 @@ namespace torch::inductor {
 // We provide NO BC guarantee for these APIs
 AOTIModelContainerRunnerCpu::AOTIModelContainerRunnerCpu(
     const std::string& model_so_path,
-    size_t num_models)
-    : AOTIModelContainerRunner(model_so_path, num_models, "cpu", "") {}
+    size_t num_models,
+    bool run_single_threaded)
+    : AOTIModelContainerRunner(
+          model_so_path,
+          num_models,
+          "cpu",
+          "",
+          run_single_threaded) {}
 
 AOTIModelContainerRunnerCpu::~AOTIModelContainerRunnerCpu() = default;
 
@@ -17,12 +23,13 @@ std::unique_ptr<AOTIModelContainerRunner> create_aoti_runner_cpu(
     const std::string& model_so_path,
     size_t num_models,
     const std::string& device_str,
-    const std::string& cubin_dir) {
+    const std::string& cubin_dir,
+    const bool run_single_threaded) {
   if (device_str != "cpu") {
     throw std::runtime_error("Incorrect device passed to aoti_runner_cpu");
   }
   return std::make_unique<AOTIModelContainerRunnerCpu>(
-      model_so_path, num_models);
+      model_so_path, num_models, run_single_threaded);
 }
 } // namespace
 
@@ -8,7 +8,8 @@ class TORCH_API AOTIModelContainerRunnerCpu : public AOTIModelContainerRunner {
  public:
   AOTIModelContainerRunnerCpu(
       const std::string& model_so_path,
-      size_t num_models = 1);
+      size_t num_models = 1,
+      const bool run_single_threaded = false);
 
   ~AOTIModelContainerRunnerCpu() override;
 };
@@ -7,12 +7,14 @@ AOTIModelContainerRunnerCuda::AOTIModelContainerRunnerCuda(
     const std::string& model_so_path,
     size_t num_models,
     const std::string& device_str,
-    const std::string& cubin_dir)
+    const std::string& cubin_dir,
+    const bool run_single_threaded)
     : AOTIModelContainerRunner(
           model_so_path,
           num_models,
           device_str,
-          cubin_dir) {}
+          cubin_dir,
+          run_single_threaded) {}
 
 AOTIModelContainerRunnerCuda::~AOTIModelContainerRunnerCuda() = default;
 
@@ -37,9 +39,10 @@ std::unique_ptr<AOTIModelContainerRunner> create_aoti_runner_cuda(
     const std::string& model_so_path,
     size_t num_models,
     const std::string& device_str,
-    const std::string& cubin_dir) {
+    const std::string& cubin_dir,
+    const bool run_single_threaded) {
   return std::make_unique<AOTIModelContainerRunnerCuda>(
-      model_so_path, num_models, device_str, cubin_dir);
+      model_so_path, num_models, device_str, cubin_dir, run_single_threaded);
 }
 } // namespace
 
@@ -17,7 +17,8 @@ class TORCH_CUDA_CPP_API AOTIModelContainerRunnerCuda
       const std::string& model_so_path,
       size_t num_models = 1,
       const std::string& device_str = "cuda",
-      const std::string& cubin_dir = "");
+      const std::string& cubin_dir = "",
+      const bool run_single_threaded = false);
 
   ~AOTIModelContainerRunnerCuda() override;
 
@@ -7,12 +7,14 @@ AOTIModelContainerRunnerXpu::AOTIModelContainerRunnerXpu(
     const std::string& model_so_path,
     size_t num_models,
     const std::string& device_str,
-    const std::string& kernel_bin_dir)
+    const std::string& kernel_bin_dir,
+    const bool run_single_threaded)
     : AOTIModelContainerRunner(
           model_so_path,
           num_models,
           device_str,
-          kernel_bin_dir) {}
+          kernel_bin_dir,
+          run_single_threaded) {}
 
 AOTIModelContainerRunnerXpu::~AOTIModelContainerRunnerXpu() = default;
 
@@ -37,9 +39,14 @@ std::unique_ptr<AOTIModelContainerRunner> create_aoti_runner_xpu(
     const std::string& model_so_path,
     size_t num_models,
     const std::string& device_str,
-    const std::string& kernel_bin_dir) {
+    const std::string& kernel_bin_dir,
+    const bool run_single_threaded) {
   return std::make_unique<AOTIModelContainerRunnerXpu>(
-      model_so_path, num_models, device_str, kernel_bin_dir);
+      model_so_path,
+      num_models,
+      device_str,
+      kernel_bin_dir,
+      run_single_threaded);
 }
 } // namespace
 
@@ -19,7 +19,8 @@ class C10_EXPORT AOTIModelContainerRunnerXpu : public AOTIModelContainerRunner {
       const std::string& model_so_path,
       size_t num_models = 1,
      const std::string& device_str = "xpu",
-      const std::string& kernel_bin_dir = "");
+      const std::string& kernel_bin_dir = "",
+      const bool run_single_threaded = false);
 
   ~AOTIModelContainerRunnerXpu() override;
 
@@ -58,6 +58,20 @@ AOTIRuntimeError AOTInductorModelContainerRun(
     AOTInductorStreamHandle stream_handle,
     AOTIProxyExecutorHandle proxy_executor_handle);
 
+// Single-threaded variant of previous.
+AOTIRuntimeError AOTInductorModelContainerRunSingleThreaded(
+    AOTInductorModelContainerHandle container_handle,
+    AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles
+                                     // are stolen; the array itself is borrowed
+    size_t num_inputs,
+    AtenTensorHandle*
+        output_handles, // array for writing output AtenTensorHandle; handles
+                        // will be stolen by the caller; the array itself is
+                        // borrowed
+    size_t num_outputs,
+    AOTInductorStreamHandle stream_handle,
+    AOTIProxyExecutorHandle proxy_executor_handle);
+
 // Retrieves the number of constants for the model.
 AOTIRuntimeError AOTInductorModelContainerGetNumConstants(
     AOTInductorModelContainerHandle container_handle,
@@ -231,6 +231,24 @@ class AOTInductorModelBase {
 #endif // USE_CUDA
   }
 
+  // Non-thread-aware variant of run(). Obviously unsafe to use in a threaded
+  // environment :)
+  void run_single_threaded(
+      AtenTensorHandle*
+          input_handles, // array of input AtenTensorHandle; handles
+                         // are stolen; the array itself is borrowed
+      AtenTensorHandle*
+          output_handles, // array for writing output AtenTensorHandle; handles
+                          // will be stolen by the caller; the array itself is
+                          // borrowed
+      DeviceStreamType stream,
+      AOTIProxyExecutorHandle proxy_executor) {
+    // don't bother with any of the run_finished stuff; this is unsafe to call
+    // in a threaded context
+    auto* model = static_cast<Model*>(this);
+    model->run_impl(input_handles, output_handles, stream, proxy_executor);
+  }
+
   std::unordered_map<std::string, AtenTensorHandle> run_const_fold(
       DeviceStreamType stream,
       AOTIProxyExecutorHandle proxy_executor,
@@ -112,6 +112,34 @@ class AOTInductorModelContainer {
     pending_models_available_.notify_one();
   }
 
+  // Non-thread-aware variant of run(). Obviously unsafe to use in a threaded
+  // environment :)
+  void run_single_threaded(
+      AtenTensorHandle*
+          input_handles, // array of input AtenTensorHandle; handles
+                         // are stolen; the array itself is borrowed
+      AtenTensorHandle*
+          output_handles, // array for writing output AtenTensorHandle; handles
+                          // will be stolen by the caller; the array itself is
+                          // borrowed
+      DeviceStreamType stream,
+      AOTIProxyExecutorHandle proxy_executor) {
+    auto* model = available_models_[0];
+
+    if (!constant_folded_) {
+      auto folded_const_map = model->run_const_fold(
+          stream, proxy_executor, /* initialization = */ true);
+      update_constant_buffer(
+          std::move(folded_const_map),
+          /* use_inactive = */ false,
+          /* validate_full_update = */ false);
+      constant_folded_ = true;
+    }
+
+    model->run_single_threaded(
+        input_handles, output_handles, stream, proxy_executor);
+  }
+
   size_t num_constants() const {
     if (this->num_models() == 0) {
       throw std::runtime_error("No available models in container!");