Support with statement on torch.Stream (#140138)
# Motivation

We propose supporting the Python with statement on `torch.Stream`. This benefits all accelerators when writing device-agnostic code. Device-specific streams are supported as well, since they generally derive from `torch.Stream`. With this PR, we can write:

```python
s1 = torch.Stream()
# Set s1 as the current stream
torch.accelerator.set_stream(s1)
with torch.Stream() as s2:
    # Inside the with block, s2 is the current stream
    assert torch.accelerator.current_stream() == s2
# After the block, the current stream is s1 again
assert torch.accelerator.current_stream() == s1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140138
Approved by: https://github.com/albanD
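The context manager is also device-aware: entering a stream that lives on another device switches the current device for the duration of the block and restores it on exit, as the new multi-device tests below exercise. A minimal sketch of that behavior (assuming at least two accelerator devices are available):

```python
import torch

torch.accelerator.set_device_index(0)
prev_stream = torch.accelerator.current_stream()

with torch.Stream(1) as s:  # a stream on device 1
    # Both the current device and the current stream have switched.
    assert torch.accelerator.current_device_index() == 1
    assert torch.accelerator.current_stream() == s

# On exit, the previous device and its current stream are restored.
assert torch.accelerator.current_device_index() == 0
assert torch.accelerator.current_stream() == prev_stream
```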
Committed by: PyTorch MergeBot
Parent: 04cb19d225
Commit: 6de110b862
```diff
@@ -79,6 +79,29 @@ class TestAccelerator(TestCase):
         ):
             torch.accelerator.current_stream(other_device)
 
+    def test_stream_context_manager(self):
+        prev_stream = torch.accelerator.current_stream()
+        with torch.Stream() as s:
+            self.assertEqual(torch.accelerator.current_stream(), s)
+        self.assertEqual(torch.accelerator.current_stream(), prev_stream)
+
+    @unittest.skipIf(not TEST_MULTIACCELERATOR, "only one accelerator detected")
+    def test_multi_device_stream_context_manager(self):
+        src_device = 0
+        dst_device = 1
+        torch.accelerator.set_device_index(src_device)
+        src_prev_stream = torch.accelerator.current_stream()
+        dst_prev_stream = torch.accelerator.current_stream(dst_device)
+        with torch.Stream(dst_device) as dst_stream:
+            self.assertEqual(torch.accelerator.current_device_index(), dst_device)
+            self.assertEqual(torch.accelerator.current_stream(), dst_stream)
+            self.assertEqual(
+                torch.accelerator.current_stream(src_device), src_prev_stream
+            )
+        self.assertEqual(torch.accelerator.current_device_index(), src_device)
+        self.assertEqual(torch.accelerator.current_stream(), src_prev_stream)
+        self.assertEqual(torch.accelerator.current_stream(dst_device), dst_prev_stream)
+
 
 if __name__ == "__main__":
     run_tests()
```
```diff
@@ -841,6 +841,27 @@ class TestCuda(TestCase):
         self.assertNotEqual(try_realloc.data_ptr(), data_ptr)
 
+    def test_stream_context_manager(self):
+        prev_stream = torch.cuda.current_stream()
+        with torch.cuda.Stream() as stream:
+            self.assertEqual(stream, torch.cuda.current_stream())
+        self.assertEqual(prev_stream, torch.cuda.current_stream())
+
+    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
+    def test_multi_device_stream_context_manager(self):
+        src_device = 0
+        dst_device = 1
+        torch.cuda.set_device(src_device)
+        src_prev_stream = torch.cuda.current_stream(src_device)
+        dst_prev_stream = torch.cuda.current_stream(dst_device)
+        with torch.cuda.Stream(dst_device) as dst_stream:
+            self.assertEqual(dst_device, torch.cuda.current_device())
+            self.assertEqual(dst_stream, torch.cuda.current_stream())
+            self.assertEqual(src_prev_stream, torch.cuda.current_stream(src_device))
+        self.assertEqual(src_device, torch.cuda.current_device())
+        self.assertEqual(src_prev_stream, torch.cuda.current_stream())
+        self.assertEqual(dst_prev_stream, torch.cuda.current_stream(dst_device))
+
     def test_noncontiguous_pinned_memory(self):
         # See issue #3266
         x = torch.arange(0, 10).view((2, 5))
```
```diff
@@ -309,6 +309,27 @@ print(torch.xpu.device_count())
         with self.assertRaisesRegex(RuntimeError, "The device index is out of range"):
             torch.accelerator.current_stream(torch.accelerator.device_count())
 
+    def test_stream_context_manager(self):
+        prev_stream = torch.xpu.current_stream()
+        with torch.xpu.Stream() as stream:
+            self.assertEqual(stream, torch.xpu.current_stream())
+        self.assertEqual(prev_stream, torch.xpu.current_stream())
+
+    @unittest.skipIf(not TEST_MULTIXPU, "only one GPU detected")
+    def test_multi_device_stream_context_manager(self):
+        src_device = 0
+        dst_device = 1
+        torch.xpu.set_device(src_device)
+        src_prev_stream = torch.xpu.current_stream(src_device)
+        dst_prev_stream = torch.xpu.current_stream(dst_device)
+        with torch.xpu.Stream(dst_device) as dst_stream:
+            self.assertEqual(dst_device, torch.xpu.current_device())
+            self.assertEqual(dst_stream, torch.xpu.current_stream())
+            self.assertEqual(src_prev_stream, torch.xpu.current_stream(src_device))
+        self.assertEqual(src_device, torch.xpu.current_device())
+        self.assertEqual(src_prev_stream, torch.xpu.current_stream())
+        self.assertEqual(dst_prev_stream, torch.xpu.current_stream(dst_device))
+
     def test_generator(self):
         torch.manual_seed(2024)
         g_state0 = torch.xpu.get_rng_state()
```
```diff
@@ -13215,7 +13215,8 @@ Stream(device, *, priority) -> Stream
 
 An in-order queue of executing the respective tasks asynchronously in first in first out (FIFO) order.
 It can control or synchronize the execution of other Stream or block the current host thread to ensure
-the correct task sequencing.
+the correct task sequencing. It supports the with statement as a context manager, ensuring that the
+operators within the with block run on the corresponding stream.
 
 See in-depth description of the CUDA behavior at :ref:`cuda-semantics` for details
 on the exact semantic that applies to all devices.
@@ -13232,7 +13233,10 @@ Returns:
 
 Example::
 
     >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
-    >>> s_cuda = torch.Stream(device='cuda')
+    >>> with torch.Stream(device='cuda') as s_cuda:
+    >>>     a = torch.randn(10, 5, device='cuda')
+    >>>     b = torch.randn(5, 10, device='cuda')
+    >>>     c = torch.mm(a, b)
 """,
 )
```
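Since work launched inside the block is enqueued on the entered stream, synchronizing that stream alone is enough before the host consumes the result. A small usage sketch building on the updated doctest (assuming a CUDA build and the generic `Stream.synchronize()` API):

```python
import torch

if torch.cuda.is_available():
    with torch.Stream(device="cuda") as s_cuda:
        a = torch.randn(10, 5, device="cuda")
        b = torch.randn(5, 10, device="cuda")
        c = torch.mm(a, b)  # enqueued on s_cuda asynchronously
    s_cuda.synchronize()  # wait only for the work queued on this stream
    print(c.cpu())
```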
```diff
@@ -94,6 +94,7 @@ static PyObject* THPStream_pynew(
   // NOLINTNEXTLINE(bugprone-signed-char-misuse)
   self->device_index = static_cast<int64_t>(stream_opt->device_index());
   self->device_type = static_cast<int64_t>(stream_opt->device_type());
+  self->context = nullptr;
 
   return (PyObject*)ptr.release();
   END_HANDLE_TH_ERRORS
@@ -112,6 +113,7 @@ PyObject* THPStream_Wrap(const c10::Stream& stream) {
   // NOLINTNEXTLINE(bugprone-signed-char-misuse)
   self->device_index = static_cast<int64_t>(stream.device_index());
   self->device_type = static_cast<int64_t>(stream.device_type());
+  self->context = nullptr;
   return ptr.release();
   END_HANDLE_TH_ERRORS
 }
```
```diff
@@ -256,6 +258,89 @@ static PyObject* THPStream_eq(THPStream* self, THPStream* other) {
   END_HANDLE_TH_ERRORS
 }
 
+static PyObject* THPStream_enter(PyObject* _self, PyObject* unused) {
+  HANDLE_TH_ERRORS
+  auto self = (THPStream*)_self;
+  c10::DeviceType stream_device_type =
+      static_cast<c10::DeviceType>(self->device_type);
+  // No operation is performed if the stream does not belong to an accelerator.
+  if (C10_UNLIKELY(!at::accelerator::isAccelerator(stream_device_type))) {
+    Py_INCREF(_self);
+    return _self;
+  }
+  c10::DeviceIndex cur_device_idx = at::accelerator::getDeviceIndex();
+  c10::DeviceIndex stream_device_idx =
+      static_cast<c10::DeviceIndex>(self->device_index);
+  // If the stream is not on the current device, switch the current device to
+  // the device of the stream.
+  if (stream_device_idx != cur_device_idx) {
+    at::accelerator::setDeviceIndex(stream_device_idx);
+  }
+  c10::Stream cur_stream = at::accelerator::getCurrentStream(stream_device_idx);
+  at::accelerator::setCurrentStream(c10::Stream::unpack3(
+      self->stream_id, stream_device_idx, stream_device_type));
+  // Save the current device index and previous stream to the context.
+  auto ctx_device_index =
+      THPObjectPtr(THPUtils_packDeviceIndex(cur_device_idx));
+  auto ctx_stream = THPObjectPtr(THPStream_Wrap(cur_stream));
+  TORCH_CHECK(!(self->context), "Stream's context should not be initialized.");
+  auto dict = THPObjectPtr(PyDict_New());
+  if (!dict) {
+    throw python_error();
+  }
+  self->context = dict.release();
+  if (PyDict_SetItemString(
+          self->context, "_ctx_device_index", ctx_device_index.get()) < 0) {
+    throw python_error();
+  }
+  if (PyDict_SetItemString(self->context, "_ctx_stream", ctx_stream.get()) <
+      0) {
+    throw python_error();
+  }
+  Py_INCREF(_self);
+  return _self;
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPStream_exit(PyObject* _self, PyObject* unused) {
+  HANDLE_TH_ERRORS
+  auto self = (THPStream*)_self;
+  // No operation is performed if the stream does not belong to an accelerator.
+  if (C10_UNLIKELY(!at::accelerator::isAccelerator(
+          static_cast<c10::DeviceType>(self->device_type)))) {
+    Py_RETURN_NONE;
+  }
+  PyObject* py_stream = nullptr;
+  if (PyDict_GetItemStringRef(self->context, "_ctx_stream", &py_stream) < 0) {
+    throw python_error();
+  }
+  auto ctx_stream = THPObjectPtr(py_stream);
+  PyObject* py_device_index = nullptr;
+  if (PyDict_GetItemStringRef(
+          self->context, "_ctx_device_index", &py_device_index) < 0) {
+    throw python_error();
+  }
+  auto ctx_device_index = THPObjectPtr(py_device_index);
+  TORCH_INTERNAL_ASSERT(
+      ctx_stream.get(), "ctx_stream should be present on the context dict.");
+  auto prev_stream = (THPStream*)(ctx_stream.get());
+  TORCH_INTERNAL_ASSERT(
+      ctx_device_index.get(),
+      "ctx_device_index should be present on the context dict.");
+  auto prev_device_index = THPUtils_unpackDeviceIndex(ctx_device_index.get());
+  at::accelerator::setCurrentStream(c10::Stream::unpack3(
+      prev_stream->stream_id,
+      static_cast<c10::DeviceIndex>(prev_stream->device_index),
+      static_cast<c10::DeviceType>(prev_stream->device_type)));
+  // Reset the current device to the previous device if they differ.
+  if (static_cast<c10::DeviceIndex>(self->device_index) != prev_device_index) {
+    at::accelerator::setDeviceIndex(prev_device_index);
+  }
+  Py_CLEAR(self->context);
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
 static PyObject* THPStream_ne(THPStream* self, THPStream* other) {
   HANDLE_TH_ERRORS
   return PyBool_FromLong(
```
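In Python terms, `__enter__` lazily creates a context dict holding the previous device index and the previous current stream of the stream's device, switches devices if needed, and makes the stream current; `__exit__` restores both and clears the dict. A rough Python-level analogue of this C++ logic (illustrative only; `_StreamContextSketch` is a hypothetical name, and the real behavior is implemented by the methods above):

```python
import torch

class _StreamContextSketch:
    """Illustrative analogue of THPStream_enter/THPStream_exit (not the real code)."""

    def __init__(self, stream: torch.Stream) -> None:
        self.stream = stream
        self._ctx = None  # like THPStream::context, created lazily on __enter__

    def __enter__(self) -> torch.Stream:
        # (The C++ code first bails out if the stream's device is not an accelerator.)
        prev_device = torch.accelerator.current_device_index()
        if self.stream.device_index != prev_device:
            torch.accelerator.set_device_index(self.stream.device_index)
        # Save the current stream *of the stream's device*, then make `stream` current.
        prev_stream = torch.accelerator.current_stream(self.stream.device_index)
        torch.accelerator.set_stream(self.stream)
        self._ctx = {"_ctx_device_index": prev_device, "_ctx_stream": prev_stream}
        return self.stream

    def __exit__(self, *exc) -> bool:
        # Restore the saved stream, then the saved device if it differs.
        torch.accelerator.set_stream(self._ctx["_ctx_stream"])
        if self.stream.device_index != self._ctx["_ctx_device_index"]:
            torch.accelerator.set_device_index(self._ctx["_ctx_device_index"])
        self._ctx = None  # mirrors Py_CLEAR(self->context)
        return False  # do not swallow exceptions
```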
```diff
@@ -321,6 +406,8 @@ static const std::initializer_list<PyMethodDef> THPStream_methods = {
      METH_VARARGS | METH_KEYWORDS,
      nullptr},
     {"__eq__", (PyCFunction)THPStream_eq, METH_O, nullptr},
+    {"__enter__", THPStream_enter, METH_NOARGS, nullptr},
+    {"__exit__", THPStream_exit, METH_VARARGS, nullptr},
     {nullptr}};
 
 static PyTypeObject THPStreamType = {
```
```diff
@@ -10,6 +10,8 @@ struct THPStream {
   int64_t stream_id;
   int64_t device_type;
   int64_t device_index;
+  // Used by the stream context manager; initialized lazily.
+  PyObject* context;
 };
 extern TORCH_API PyTypeObject* THPStreamClass;
```
```diff
@@ -15,8 +15,9 @@ class Stream(torch._C._CudaStreamBase):
     r"""Wrapper around a CUDA stream.
 
     A CUDA stream is a linear sequence of execution that belongs to a specific
-    device, independent from other streams. See :ref:`cuda-semantics` for
-    details.
+    device, independent from other streams. It supports the with statement as
+    a context manager, ensuring that operators within the with block run on
+    the corresponding stream. See :ref:`cuda-semantics` for details.
 
     Args:
         device(torch.device or int, optional): a device on which to allocate
```
```diff
@@ -15,7 +15,9 @@ class Stream(torch._C._XpuStreamBase):
     r"""Wrapper around a XPU stream.
 
     A XPU stream is a linear sequence of execution that belongs to a specific
-    device, independent from other streams.
+    device, independent from other streams. It supports the with statement as
+    a context manager, ensuring that operators within the with block run on
+    the corresponding stream.
 
     Args:
         device(torch.device or int, optional): a device on which to allocate
```