Mirror of https://github.com/pytorch/pytorch.git (last synced 2025-10-21 13:44:15 +08:00).
This reverts commit c7515da7b00de40942c83dc5856b6daec727e280. Reverted https://github.com/pytorch/pytorch/pull/140979 on behalf of https://github.com/huydhn because the change was reported to break internal code ([comment](https://github.com/pytorch/pytorch/pull/140979#issuecomment-2657361940)).
92 lines
3.4 KiB
C++
92 lines
3.4 KiB
C++
#include <torch/csrc/python_headers.h>
|
|
|
|
#include <pybind11/chrono.h>
|
|
|
|
#include <torch/csrc/jit/python/pybind_utils.h>
|
|
#include <torch/csrc/utils/pybind.h>
|
|
|
|
#include <ATen/cuda/CUDAGraph.h>
|
|
#include <c10/cuda/CUDAGraphsC10Utils.h>
|
|
|
|
// Cargo culted partially from csrc/distributed/c10d/init.cpp
// and partially from csrc/cuda/Stream.cpp.
// THCPStream_init is also declared at global scope.

// Because THCPGraph_init is forward declared in the only consumer
// (csrc/Module.cpp) I don't think we need a Graph.h.
|
// Shorthand for binding a C++ type T to Python with std::shared_ptr
// ownership semantics (same alias convention as csrc/distributed/c10d).
template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
|
// Registers the CUDA-graph Python bindings (_graph_pool_handle and the
// _CUDAGraph class) on the given module (torch._C).
//
// Fix: the original registered "debug_dump" twice — once without a named
// argument and once with py::arg("debug_path") — wrapping the same member
// function. The first registration was redundant (pybind11 tries overloads
// in registration order, and keyword calls already fell through to the
// second), so only the py::arg("debug_path") binding is kept. Positional
// and keyword calls both continue to work.
void THCPGraph_init(PyObject* module) {
  // Pybind11 patch notes say "py::module_" is more up-to-date syntax,
  // but CI linter and some builds prefer "module".
  auto torch_C_m = py::handle(module).cast<py::module>();

  torch_C_m.def("_graph_pool_handle", &::at::cuda::graph_pool_handle);

  shared_ptr_class_<::at::cuda::CUDAGraph>(torch_C_m, "_CUDAGraph")
      .def(py::init<>())
      .def(
          "capture_begin",
          // Maps the Python-level capture_error_mode string onto the CUDA
          // stream-capture mode enum, then begins capture. Raises (via
          // TORCH_CHECK) on an unrecognized mode string.
          [](::at::cuda::CUDAGraph& self,
             std::optional<c10::cuda::MempoolId_t> pool_opt,
             const std::string& capture_error_mode) {
            cudaStreamCaptureMode capture_mode{};
            // {0, 0} is the "no pool supplied" sentinel id.
            c10::cuda::MempoolId_t pool = pool_opt.has_value()
                ? pool_opt.value()
                : c10::cuda::MempoolId_t{0, 0};
            if (capture_error_mode == "global") {
              capture_mode = cudaStreamCaptureModeGlobal;
            } else if (capture_error_mode == "thread_local") {
              capture_mode = cudaStreamCaptureModeThreadLocal;
            } else if (capture_error_mode == "relaxed") {
              capture_mode = cudaStreamCaptureModeRelaxed;
            } else {
              TORCH_CHECK(
                  false,
                  "Unknown capture error mode. Expected `global`, `thread_local`, or `relaxed`, got ",
                  capture_error_mode);
            }
            return self.capture_begin(pool, capture_mode);
          },
          py::arg("pool"),
          py::arg("capture_error_mode"),
          py::call_guard<py::gil_scoped_release>())
      .def(
          "capture_end",
          torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::capture_end))
      .def(
          "register_generator_state",
          [](::at::cuda::CUDAGraph& self, py::handle raw_generator) {
            auto generator = THPGenerator_Unwrap(raw_generator.ptr());
            // We've unwrapped the Python object to a C++ object,
            // so we can release the GIL before calling into C++.
            py::gil_scoped_release release;
            return self.register_generator_state(generator);
          },
          py::arg("generator"))
      .def(
          "replay",
          torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::replay))
      .def(
          "reset",
          torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::reset))
      .def(
          "pool",
          torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::pool))
      .def(
          "enable_debug_mode",
          torch::wrap_pybind_function_no_gil(
              &::at::cuda::CUDAGraph::enable_debug_mode))
      .def(
          "debug_dump",
          torch::wrap_pybind_function_no_gil(
              &::at::cuda::CUDAGraph::debug_dump),
          py::arg("debug_path"));
}