Support with statement on torch.Stream (#140138)

# Motivation
We propose to support the Python `with` statement on `torch.Stream`. This benefits all accelerators when writing device-agnostic code. Device-specific streams are also supported, since they generally derive from `torch.Stream`.

With this PR, we can write code like this:
```python
s1 = torch.Stream()
# Set s1 as the current stream
torch.accelerator.set_stream(s1)
with torch.Stream() as s2:
    # Inside the with block, s2 is the current stream
    assert torch.accelerator.current_stream() == s2
# After exiting the block, the current stream is restored to s1
assert torch.accelerator.current_stream() == s1
```
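Because the context manager restores the previous stream on exit (and the previous device, for streams on another device, as the tests below verify), device-agnostic code can run work on a side stream without backend-specific calls. Here is a minimal sketch, assuming an accelerator (CUDA, XPU, etc.) is available; `wait_stream` is the pre-existing `torch.Stream` method for ordering one stream after another:

```python
import torch

device = torch.accelerator.current_accelerator()
side_stream = torch.Stream()  # a stream on the current accelerator

with side_stream:
    # Work launched here is enqueued on side_stream.
    a = torch.randn(1024, 1024, device=device)
    b = a @ a

# Back on the previous stream: order it after side_stream before using b.
torch.accelerator.current_stream().wait_stream(side_stream)
print(b.sum().item())
```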

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140138
Approved by: https://github.com/albanD
Yu, Guangye
2025-01-09 23:41:36 +00:00
committed by PyTorch MergeBot
parent 04cb19d225
commit 6de110b862
8 changed files with 166 additions and 5 deletions

test/test_accelerator.py

@@ -79,6 +79,29 @@ class TestAccelerator(TestCase):
        ):
            torch.accelerator.current_stream(other_device)

    def test_stream_context_manager(self):
        prev_stream = torch.accelerator.current_stream()
        with torch.Stream() as s:
            self.assertEqual(torch.accelerator.current_stream(), s)
        self.assertEqual(torch.accelerator.current_stream(), prev_stream)

    @unittest.skipIf(not TEST_MULTIACCELERATOR, "only one accelerator detected")
    def test_multi_device_stream_context_manager(self):
        src_device = 0
        dst_device = 1
        torch.accelerator.set_device_index(src_device)
        src_prev_stream = torch.accelerator.current_stream()
        dst_prev_stream = torch.accelerator.current_stream(dst_device)
        with torch.Stream(dst_device) as dst_stream:
            self.assertEqual(torch.accelerator.current_device_index(), dst_device)
            self.assertEqual(torch.accelerator.current_stream(), dst_stream)
            self.assertEqual(
                torch.accelerator.current_stream(src_device), src_prev_stream
            )
        self.assertEqual(torch.accelerator.current_device_index(), src_device)
        self.assertEqual(torch.accelerator.current_stream(), src_prev_stream)
        self.assertEqual(torch.accelerator.current_stream(dst_device), dst_prev_stream)

if __name__ == "__main__":
    run_tests()

test/test_cuda.py

@@ -841,6 +841,27 @@ class TestCuda(TestCase):
        self.assertNotEqual(try_realloc.data_ptr(), data_ptr)

    def test_stream_context_manager(self):
        prev_stream = torch.cuda.current_stream()
        with torch.cuda.Stream() as stream:
            self.assertEqual(stream, torch.cuda.current_stream())
        self.assertEqual(prev_stream, torch.cuda.current_stream())

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_multi_device_stream_context_manager(self):
        src_device = 0
        dst_device = 1
        torch.cuda.set_device(src_device)
        src_prev_stream = torch.cuda.current_stream(src_device)
        dst_prev_stream = torch.cuda.current_stream(dst_device)
        with torch.cuda.Stream(dst_device) as dst_stream:
            self.assertEqual(dst_device, torch.cuda.current_device())
            self.assertEqual(dst_stream, torch.cuda.current_stream())
            self.assertEqual(src_prev_stream, torch.cuda.current_stream(src_device))
        self.assertEqual(src_device, torch.cuda.current_device())
        self.assertEqual(src_prev_stream, torch.cuda.current_stream())
        self.assertEqual(dst_prev_stream, torch.cuda.current_stream(dst_device))

    def test_noncontiguous_pinned_memory(self):
        # See issue #3266
        x = torch.arange(0, 10).view((2, 5))

test/test_xpu.py

@@ -309,6 +309,27 @@ print(torch.xpu.device_count())
        with self.assertRaisesRegex(RuntimeError, "The device index is out of range"):
            torch.accelerator.current_stream(torch.accelerator.device_count())

    def test_stream_context_manager(self):
        prev_stream = torch.xpu.current_stream()
        with torch.xpu.Stream() as stream:
            self.assertEqual(stream, torch.xpu.current_stream())
        self.assertEqual(prev_stream, torch.xpu.current_stream())

    @unittest.skipIf(not TEST_MULTIXPU, "only one GPU detected")
    def test_multi_device_stream_context_manager(self):
        src_device = 0
        dst_device = 1
        torch.xpu.set_device(src_device)
        src_prev_stream = torch.xpu.current_stream(src_device)
        dst_prev_stream = torch.xpu.current_stream(dst_device)
        with torch.xpu.Stream(dst_device) as dst_stream:
            self.assertEqual(dst_device, torch.xpu.current_device())
            self.assertEqual(dst_stream, torch.xpu.current_stream())
            self.assertEqual(src_prev_stream, torch.xpu.current_stream(src_device))
        self.assertEqual(src_device, torch.xpu.current_device())
        self.assertEqual(src_prev_stream, torch.xpu.current_stream())
        self.assertEqual(dst_prev_stream, torch.xpu.current_stream(dst_device))

    def test_generator(self):
        torch.manual_seed(2024)
        g_state0 = torch.xpu.get_rng_state()

torch/_torch_docs.py

@@ -13215,7 +13215,8 @@ Stream(device, *, priority) -> Stream

An in-order queue of executing the respective tasks asynchronously in first in first out (FIFO) order.
It can control or synchronize the execution of other Stream or block the current host thread to ensure
the correct task sequencing. It supports the with statement as a context
manager, ensuring that operators within the with block run on the
corresponding stream.

See in-depth description of the CUDA behavior at :ref:`cuda-semantics` for details
on the exact semantic that applies to all devices.

@@ -13232,7 +13233,10 @@ Returns:

Example::

    >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
    >>> with torch.Stream(device='cuda') as s_cuda:
    >>>     a = torch.randn(10, 5, device='cuda')
    >>>     b = torch.randn(5, 10, device='cuda')
    >>>     c = torch.mm(a, b)
""",
)

torch/csrc/Stream.cpp

@@ -94,6 +94,7 @@ static PyObject* THPStream_pynew(
  // NOLINTNEXTLINE(bugprone-signed-char-misuse)
  self->device_index = static_cast<int64_t>(stream_opt->device_index());
  self->device_type = static_cast<int64_t>(stream_opt->device_type());
  self->context = nullptr;

  return (PyObject*)ptr.release();
  END_HANDLE_TH_ERRORS
@@ -112,6 +113,7 @@ PyObject* THPStream_Wrap(const c10::Stream& stream) {
  // NOLINTNEXTLINE(bugprone-signed-char-misuse)
  self->device_index = static_cast<int64_t>(stream.device_index());
  self->device_type = static_cast<int64_t>(stream.device_type());
  self->context = nullptr;

  return ptr.release();
  END_HANDLE_TH_ERRORS
}
@@ -256,6 +258,89 @@ static PyObject* THPStream_eq(THPStream* self, THPStream* other) {
  END_HANDLE_TH_ERRORS
}

static PyObject* THPStream_enter(PyObject* _self, PyObject* unused) {
  HANDLE_TH_ERRORS
  auto self = (THPStream*)_self;
  c10::DeviceType stream_device_type =
      static_cast<c10::DeviceType>(self->device_type);
  // No operation is performed if the stream does not belong to an accelerator.
  if (C10_UNLIKELY(!at::accelerator::isAccelerator(stream_device_type))) {
    Py_INCREF(_self);
    return _self;
  }
  c10::DeviceIndex cur_device_idx = at::accelerator::getDeviceIndex();
  c10::DeviceIndex stream_device_idx =
      static_cast<c10::DeviceIndex>(self->device_index);
  // If the stream is not on the current device, switch the current device to
  // the device of the stream.
  if (stream_device_idx != cur_device_idx) {
    at::accelerator::setDeviceIndex(stream_device_idx);
  }
  c10::Stream cur_stream = at::accelerator::getCurrentStream(stream_device_idx);
  at::accelerator::setCurrentStream(c10::Stream::unpack3(
      self->stream_id, stream_device_idx, stream_device_type));
  // Save the current device index and previous stream to the context.
  auto ctx_device_index =
      THPObjectPtr(THPUtils_packDeviceIndex(cur_device_idx));
  auto ctx_stream = THPObjectPtr(THPStream_Wrap(cur_stream));
  TORCH_CHECK(!(self->context), "Stream's context should not be initialized.");
  auto dict = THPObjectPtr(PyDict_New());
  if (!dict) {
    throw python_error();
  }
  self->context = dict.release();
  if (PyDict_SetItemString(
          self->context, "_ctx_device_index", ctx_device_index.get()) < 0) {
    throw python_error();
  }
  if (PyDict_SetItemString(self->context, "_ctx_stream", ctx_stream.get()) <
      0) {
    throw python_error();
  }
  Py_INCREF(_self);
  return _self;
  END_HANDLE_TH_ERRORS
}

static PyObject* THPStream_exit(PyObject* _self, PyObject* unused) {
  HANDLE_TH_ERRORS
  auto self = (THPStream*)_self;
  // No operation is performed if the stream does not belong to an accelerator.
  if (C10_UNLIKELY(!at::accelerator::isAccelerator(
          static_cast<c10::DeviceType>(self->device_type)))) {
    Py_RETURN_NONE;
  }
  PyObject* py_stream = nullptr;
  if (PyDict_GetItemStringRef(self->context, "_ctx_stream", &py_stream) < 0) {
    throw python_error();
  }
  auto ctx_stream = THPObjectPtr(py_stream);
  PyObject* py_device_index = nullptr;
  if (PyDict_GetItemStringRef(
          self->context, "_ctx_device_index", &py_device_index) < 0) {
    throw python_error();
  }
  auto ctx_device_index = THPObjectPtr(py_device_index);
  TORCH_INTERNAL_ASSERT(
      ctx_stream.get(), "ctx_stream should be present on the context dict.");
  auto prev_stream = (THPStream*)(ctx_stream.get());
  TORCH_INTERNAL_ASSERT(
      ctx_device_index.get(),
      "ctx_device_index should be present on the context dict.");
  auto prev_device_index = THPUtils_unpackDeviceIndex(ctx_device_index.get());
  at::accelerator::setCurrentStream(c10::Stream::unpack3(
      prev_stream->stream_id,
      static_cast<c10::DeviceIndex>(prev_stream->device_index),
      static_cast<c10::DeviceType>(prev_stream->device_type)));
  // Reset the current device to the previous device if they differ.
  if (static_cast<c10::DeviceIndex>(self->device_index) != prev_device_index) {
    at::accelerator::setDeviceIndex(prev_device_index);
  }
  Py_CLEAR(self->context);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* THPStream_ne(THPStream* self, THPStream* other) {
  HANDLE_TH_ERRORS
  return PyBool_FromLong(
@@ -321,6 +406,8 @@ static const std::initializer_list<PyMethodDef> THPStream_methods = {
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"__eq__", (PyCFunction)THPStream_eq, METH_O, nullptr},
    {"__enter__", THPStream_enter, METH_NOARGS, nullptr},
    {"__exit__", THPStream_exit, METH_VARARGS, nullptr},
    {nullptr}};

static PyTypeObject THPStreamType = {
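In Python terms, the `__enter__`/`__exit__` pair above behaves roughly like the sketch below. This is an illustrative rendering built on the `torch.accelerator` calls exercised by the tests earlier, not the actual implementation (which is the C++ shown above); `_StreamContextSketch` is a hypothetical name:

```python
import torch

class _StreamContextSketch:
    """Hypothetical pure-Python equivalent of THPStream_enter/THPStream_exit."""

    def __init__(self, stream: torch.Stream):
        self.stream = stream

    def __enter__(self):
        # Save the device that was current before entering.
        self._prev_device = torch.accelerator.current_device_index()
        if self.stream.device_index != self._prev_device:
            torch.accelerator.set_device_index(self.stream.device_index)
        # Save the stream that was current on the stream's own device.
        self._prev_stream = torch.accelerator.current_stream(
            self.stream.device_index
        )
        torch.accelerator.set_stream(self.stream)
        return self.stream

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the previous stream first, then the previous device.
        torch.accelerator.set_stream(self._prev_stream)
        if self.stream.device_index != self._prev_device:
            torch.accelerator.set_device_index(self._prev_device)
        return False  # never suppress exceptions
```

Note that `__exit__` restores the stream that was current on the stream's device, not on the original device, matching the `_ctx_stream` saved by the C++ `__enter__`. Also, `__exit__` is registered with `METH_VARARGS` because Python passes it `(exc_type, exc_value, traceback)`, while `__enter__` takes no arguments (`METH_NOARGS`).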

torch/csrc/Stream.h

@@ -10,6 +10,8 @@ struct THPStream {
  int64_t stream_id;
  int64_t device_type;
  int64_t device_index;
  // Used to switch stream context management, initialized lazily.
  PyObject* context;
};

extern TORCH_API PyTypeObject* THPStreamClass;

torch/cuda/streams.py

@@ -15,8 +15,9 @@ class Stream(torch._C._CudaStreamBase):
    r"""Wrapper around a CUDA stream.

    A CUDA stream is a linear sequence of execution that belongs to a specific
    device, independent from other streams. It supports the with statement as
    a context manager, ensuring that operators within the with block run on
    the corresponding stream. See :ref:`cuda-semantics` for details.

    Args:
        device(torch.device or int, optional): a device on which to allocate

torch/xpu/streams.py

@@ -15,7 +15,9 @@ class Stream(torch._C._XpuStreamBase):
    r"""Wrapper around an XPU stream.

    An XPU stream is a linear sequence of execution that belongs to a specific
    device, independent from other streams. It supports the with statement as
    a context manager, ensuring that operators within the with block run on
    the corresponding stream.

    Args:
        device(torch.device or int, optional): a device on which to allocate