Revert "Remove MemPoolContext (#154042)"

This reverts commit 3b38989b5f8f918cf1ad38bdade059608544af4b.

Reverted https://github.com/pytorch/pytorch/pull/154042 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/154042#issuecomment-2921401100))
This commit is contained in:
PyTorch MergeBot
2025-05-30 06:53:35 +00:00
parent 0fdd568b78
commit d173ba5a75
14 changed files with 243 additions and 139 deletions

View File

@ -721,24 +721,8 @@ CapturedTraceback* getFromContext(
"attempting to gather stack context from the wrong StackContext type.");
}
PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* arg) {
PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* noargs) {
HANDLE_TH_ERRORS
c10::cuda::MempoolId_t mempool_id = {0, 0};
if (arg && arg != Py_None) {
TORCH_CHECK(PyTuple_Check(arg), "mempool_id must be a tuple");
Py_ssize_t size = PyTuple_Size(arg);
TORCH_CHECK(size == 2, "mempool_id must be a tuple of 2 integers");
auto id1 = THPObjectPtr(PyTuple_GetItem(arg, 0));
auto id2 = THPObjectPtr(PyTuple_GetItem(arg, 1));
TORCH_CHECK(
THPUtils_checkLong(id1) && THPUtils_checkLong(id2),
"mempool_id elements must be integers");
mempool_id = c10::cuda::MempoolId_t(
static_cast<int64_t>(THPUtils_unpackLong(id1)),
static_cast<int64_t>(THPUtils_unpackLong(id2)));
}
using c10::cuda::CUDACachingAllocator::BlockInfo;
using c10::cuda::CUDACachingAllocator::SegmentInfo;
@ -818,7 +802,7 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* arg) {
return segmentDict;
};
auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(mempool_id);
auto snapshot = c10::cuda::CUDACachingAllocator::snapshot();
py::list segments;
@ -2027,7 +2011,7 @@ static struct PyMethodDef _THCPModule_methods[] = {
THCPModule_resetPeakMemoryStats,
METH_O,
nullptr},
{"_cuda_memorySnapshot", THCPModule_memorySnapshot, METH_O, nullptr},
{"_cuda_memorySnapshot", THCPModule_memorySnapshot, METH_NOARGS, nullptr},
{"_cuda_attach_out_of_memory_observer",
THCPModule_attachOutOfMemoryObserver,
METH_O,