Add query and synchronize to c10::Stream (#59560)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59560

`at::cuda::CUDAStream` has the `query` and `synchronize` methods, but `c10::Stream` does not, and I couldn't find any generic way to query or synchronize a stream given only a `c10::Stream`. Hence I added the `queryStream` and `synchronizeStream` helpers to the DeviceGuardImpl interface, and then defined `query` and `synchronize` on `c10::Stream` in terms of them. (I had to define them out-of-line to circumvent a circular dependency.)
ghstack-source-id: 130932249

Test Plan: CI

Reviewed By: ezyang

Differential Revision: D28931377

fbshipit-source-id: cd0c19cf021e305d0c0cf9af364afb445d010248
This commit is contained in:
Luca Wehrstedt
2021-06-10 01:41:25 -07:00
committed by Facebook GitHub Bot
parent f11120967e
commit e7cccc23b9
6 changed files with 78 additions and 1 deletions

View File

@ -190,6 +190,17 @@ struct HIPGuardImplMasqueradingAsCUDA final : public c10::impl::DeviceGuardImplI
return (err == hipSuccess);
}
// Stream-related functions
bool queryStream(const Stream& stream) const override {
HIPStreamMasqueradingAsCUDA hip_stream{stream};
return hip_stream.query();
}
void synchronizeStream(const Stream& stream) const override {
HIPStreamMasqueradingAsCUDA hip_stream{stream};
hip_stream.synchronize();
}
void recordDataPtrOnStream(
const c10::DataPtr& data_ptr,
const Stream& stream) const override {

View File

@ -1,7 +1,22 @@
#include <c10/core/Stream.h>
#include <c10/core/impl/VirtualGuardImpl.h>
namespace c10 {
// Return whether all asynchronous work previously enqueued on this stream
// has completed running on the device.
bool Stream::query() const {
impl::VirtualGuardImpl impl{device_.type()};
return impl.queryStream(*this);
}
// Wait (by blocking the calling thread) until all asynchronous work enqueued
// on this stream has completed running on the device.
void Stream::synchronize() const {
impl::VirtualGuardImpl impl{device_.type()};
impl.synchronizeStream(*this);
}
// Not very parsable, but I don't know a good compact syntax for streams.
// Feel free to change this into something more compact if needed.
std::ostream& operator<<(std::ostream& stream, const Stream& s) {

View File

@ -54,7 +54,7 @@ using StreamId = int32_t;
* functionality (e.g., get the cudaStream_t of a CUDA stream.) There are
* wrapper classes which provide this functionality, e.g., CUDAStream.
*/
class Stream final {
class C10_API Stream final {
private:
Device device_;
StreamId id_;
@ -107,6 +107,14 @@ class Stream final {
event.block(*this);
}
// Return whether all asynchronous work previously enqueued on this stream
// has completed running on the device.
//
// Only declared here; the definition lives out-of-line in the accompanying
// source file because it relies on impl::VirtualGuardImpl, and defining it
// inline would introduce a circular dependency. Backends whose guard
// implementation does not override queryStream() fail with an error.
bool query() const;
// Wait (by blocking the calling thread) until all asynchronous work enqueued
// on this stream has completed running on the device.
//
// Only declared here; the definition lives out-of-line in the accompanying
// source file because it relies on impl::VirtualGuardImpl, and defining it
// inline would introduce a circular dependency. Backends whose guard
// implementation does not override synchronizeStream() fail with an error.
void synchronize() const;
// The purpose of this function is to more conveniently permit binding
// of Stream to and from Python. Without packing, I have to setup a whole
// class with two fields (device and stream id); with packing I can just

View File

@ -176,6 +176,22 @@ struct C10_API DeviceGuardImplInterface {
*/
virtual DeviceIndex deviceCount() const noexcept = 0;
/**
 * Return true if all the work previously enqueued on the stream for
 * asynchronous execution has completed running on the device.
 *
 * This default implementation never returns a value: it unconditionally
 * fails through TORCH_CHECK with "Backend doesn't support querying streams."
 * Backends that can answer the question override this method.
 */
virtual bool queryStream(const Stream& stream) const {
TORCH_CHECK(false, "Backend doesn't support querying streams.");
}
/**
 * Wait (by blocking the calling thread) until all the work previously
 * enqueued on the stream has completed running on the device.
 *
 * This default implementation does not wait: it unconditionally fails
 * through TORCH_CHECK with "Backend doesn't support synchronizing streams."
 * Backends that can synchronize a stream override this method.
 */
virtual void synchronizeStream(const Stream& stream) const {
TORCH_CHECK(false, "Backend doesn't support synchronizing streams.");
}
/**
* Ensure the caching allocator (if any) is aware that the given DataPtr is
* being used on the given stream, and that it should thus avoid recycling the
@ -241,6 +257,14 @@ struct NoOpDeviceGuardImpl final : public DeviceGuardImplInterface {
}
void destroyEvent(void* event, const DeviceIndex device_index)
const noexcept override {}
// Stream-related functions
// Always reports the stream's work as complete: returns true unconditionally
// and never inspects `stream`.
bool queryStream(const Stream& stream) const override {
return true;
}
// Returns immediately: the body is intentionally empty and `stream` is never
// inspected.
void synchronizeStream(const Stream& stream) const override {
// Don't wait for anything.
}
};
// The registry is NON-owning. Each stored pointer is std::atomic so

View File

@ -69,6 +69,13 @@ class VirtualGuardImpl final : public DeviceGuardImplInterface {
impl_->destroyEvent(event, device_index);
}
bool queryStream(const Stream& stream) const override {
return impl_->queryStream(stream);
}
void synchronizeStream(const Stream& stream) const override {
impl_->synchronizeStream(stream);
}
void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream)
const override {
impl_->recordDataPtrOnStream(data_ptr, stream);

View File

@ -1,5 +1,6 @@
#pragma once
#include <c10/core/DeviceGuard.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
@ -170,6 +171,17 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
return (err == cudaSuccess);
}
// Stream-related functions
bool queryStream(const Stream& stream) const override {
CUDAStream cuda_stream{stream};
return cuda_stream.query();
}
void synchronizeStream(const Stream& stream) const override {
CUDAStream cuda_stream{stream};
cuda_stream.synchronize();
}
void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream)
const override {
CUDAStream cuda_stream{stream};