mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	[Intel GPU] Support getStreamFromExternal for XPU. (#140268)
In the AOT Inductor scenario, the GPU stream can be created outside of the `XPUStream` pool, and we need to create an `XPUStream` that refers to this stream for the common logic of AOTI; for example, a stream guard is a guard for `XPUStream`. So we add getStreamFromExternal, following the design of CUDAStream. Pull Request resolved: https://github.com/pytorch/pytorch/pull/140268 Approved by: https://github.com/desertfire, https://github.com/jansel, https://github.com/EikanWang
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							843018f407
						
					
				
				
					commit
					3d227ae315
				
			| @ -37,21 +37,28 @@ thread_local std::unique_ptr<StreamId[]> current_streams = nullptr; | |||||||
| // ~~~~~~~~~~~~~~~~~~~~~~~~~~ | // ~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||||
| // How do we assign stream IDs? | // How do we assign stream IDs? | ||||||
| // | // | ||||||
| // -- 57 bits --  -- 5 bits -----  -- 3 bits -- | // -- 55 bits --  -- 5 bits -----  -- 3 bits --  -- 1 bit -- | ||||||
| //     zeros      StreamIdIndex    StreamIdType | //     zeros      StreamIdIndex    StreamIdType  Ext/native stream | ||||||
| // | // | ||||||
| // Where StreamIdType: | // Where StreamIdType: | ||||||
| //  000 = normal priority queue | //  000 = normal priority queue | ||||||
| //  001 = high priority queue | //  001 = high priority queue | ||||||
| // | // | ||||||
|  | // for external stream, StreamID is a sycl::queue* pointer | ||||||
|  | // this means that last bit will always be 0 | ||||||
|  | // so when constructing StreamId for a native stream we set last bit to 1 | ||||||
|  | // to distinguish between native and external streams | ||||||
|  | // | ||||||
| // StreamId is 64-bit, so we can just rely on regular promotion rules. | // StreamId is 64-bit, so we can just rely on regular promotion rules. | ||||||
| // We rely on StreamIdIndex and StreamIdType being non-negative; | // We rely on StreamIdIndex and StreamIdType being non-negative; | ||||||
|  |  | ||||||
| using StreamIdIndex = uint8_t; | using StreamIdIndex = uint8_t; | ||||||
|  | // Classifies a StreamId. NORMAL and HIGH are encoded into the id by | ||||||
|  | // makeStreamId; EXT is what streamIdType reports for a non-zero id whose | ||||||
|  | // lowest bit is 0 (an external sycl::queue* pointer). | ||||||
| enum class StreamIdType : uint8_t { | enum class StreamIdType : uint8_t { | ||||||
|   // The higher the type number, the higher the priority. |   // The higher the type number, the higher the priority. | ||||||
|  |   // EXT is used for external streams, which we don't know the priority of. | ||||||
|  |   // XPUStream::priority() reports EXT streams with NORMAL priority. | ||||||
|   NORMAL = 0x0, |   NORMAL = 0x0, | ||||||
|   HIGH = 0X1, |   HIGH = 0X1, | ||||||
|  |   EXT = 0x7, | ||||||
| }; | }; | ||||||
|  |  | ||||||
| inline std::ostream& operator<<(std::ostream& stream, StreamIdType q) { | inline std::ostream& operator<<(std::ostream& stream, StreamIdType q) { | ||||||
| @ -60,6 +67,8 @@ inline std::ostream& operator<<(std::ostream& stream, StreamIdType q) { | |||||||
|       return stream << "NORMAL"; |       return stream << "NORMAL"; | ||||||
|     case StreamIdType::HIGH: |     case StreamIdType::HIGH: | ||||||
|       return stream << "HIGH"; |       return stream << "HIGH"; | ||||||
|  |     case StreamIdType::EXT: | ||||||
|  |       return stream << "EXT"; | ||||||
|     default: |     default: | ||||||
|       break; |       break; | ||||||
|   } |   } | ||||||
| @ -67,8 +76,13 @@ inline std::ostream& operator<<(std::ostream& stream, StreamIdType q) { | |||||||
| } | } | ||||||
|  |  | ||||||
| inline StreamIdType streamIdType(StreamId s) { | inline StreamIdType streamIdType(StreamId s) { | ||||||
|  |   // Externally allocated streams use the sycl::queue* pointer as their id, | ||||||
|  |   // so the last bit will be 0 | ||||||
|  |   if ((!(s & 1) && s)) { | ||||||
|  |     return StreamIdType(StreamIdType::EXT); | ||||||
|  |   } | ||||||
|   int mask_for_type = (1 << kStreamTypeBits) - 1; |   int mask_for_type = (1 << kStreamTypeBits) - 1; | ||||||
|   auto st = static_cast<StreamIdType>(s & mask_for_type); |   auto st = static_cast<StreamIdType>((s >> 1) & mask_for_type); | ||||||
|   TORCH_CHECK( |   TORCH_CHECK( | ||||||
|       st == StreamIdType::NORMAL || st == StreamIdType::HIGH, |       st == StreamIdType::NORMAL || st == StreamIdType::HIGH, | ||||||
|       "invalid StreamId: ", |       "invalid StreamId: ", | ||||||
| @ -78,12 +92,12 @@ inline StreamIdType streamIdType(StreamId s) { | |||||||
|  |  | ||||||
|  | // Extracts the pool index from a native StreamId: shifts past the | ||||||
|  | // native/external flag bit and the kStreamTypeBits type bits, then keeps the | ||||||
|  | // low kStreamsPerPoolBits bits. | ||||||
| inline StreamIdIndex streamIdIndex(StreamId s) { | inline StreamIdIndex streamIdIndex(StreamId s) { | ||||||
|   return static_cast<StreamIdIndex>( |   return static_cast<StreamIdIndex>( | ||||||
|       (s >> kStreamTypeBits) & ((1 << kStreamsPerPoolBits) - 1)); |       (s >> (kStreamTypeBits + 1)) & ((1 << kStreamsPerPoolBits) - 1)); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Builds a native StreamId from a stream type and a pool index. The lowest | ||||||
|  | // bit is always set to 1 so the id cannot be confused with an external | ||||||
|  | // sycl::queue* pointer, whose lowest bit is 0. | ||||||
| inline StreamId makeStreamId(StreamIdType st, StreamIdIndex si) { | inline StreamId makeStreamId(StreamIdType st, StreamIdIndex si) { | ||||||
|   return (static_cast<StreamId>(si) << kStreamTypeBits) | |   return (static_cast<StreamId>(si) << (kStreamTypeBits + 1)) | | ||||||
|       static_cast<StreamId>(st); |       (static_cast<StreamId>(st) << 1) | 1; | ||||||
| } | } | ||||||
|  |  | ||||||
| void initGlobalStreamState() { | void initGlobalStreamState() { | ||||||
| @ -166,6 +180,14 @@ XPUStream XPUStreamForId(DeviceIndex device_index, StreamId stream_id) { | |||||||
|  | // Returns the stream priority as the negated StreamIdType value: 0 for | ||||||
|  | // NORMAL and -1 for HIGH. External streams are reported as NORMAL (0). | ||||||
| int XPUStream::priority() const { | int XPUStream::priority() const { | ||||||
|   StreamId stream_id = stream_.id(); |   StreamId stream_id = stream_.id(); | ||||||
|   StreamIdType st = streamIdType(stream_id); |   StreamIdType st = streamIdType(stream_id); | ||||||
|  |   // For an external queue which is not created in XPUStream, we can not trace | ||||||
|  |   // the priority. Workaround here since sycl doesn't support get priority from | ||||||
|  |   // a sycl::queue, like cudaStreamGetPriority . | ||||||
|  |   // TODO: remove this workaround when sycl supports get priority from a | ||||||
|  |   // sycl::queue. | ||||||
|  |   if (st == StreamIdType::EXT) { | ||||||
|  |     st = StreamIdType::NORMAL; | ||||||
|  |   } | ||||||
|   // StreamIdType and priority number are inversely related. |   // StreamIdType and priority number are inversely related. | ||||||
|   return -static_cast<int>(st); |   return -static_cast<int>(st); | ||||||
| } | } | ||||||
| @ -177,6 +199,8 @@ sycl::queue& XPUStream::queue() const { | |||||||
|   StreamIdType st = streamIdType(stream_id); |   StreamIdType st = streamIdType(stream_id); | ||||||
|   StreamIdIndex si = streamIdIndex(stream_id); |   StreamIdIndex si = streamIdIndex(stream_id); | ||||||
|   switch (st) { |   switch (st) { | ||||||
|  |     case StreamIdType::EXT: | ||||||
|  |       return *(reinterpret_cast<sycl::queue*>(stream_id)); | ||||||
|     case StreamIdType::NORMAL: |     case StreamIdType::NORMAL: | ||||||
|     case StreamIdType::HIGH: |     case StreamIdType::HIGH: | ||||||
|       return *streams[device_index][static_cast<uint8_t>(st)][si]; |       return *streams[device_index][static_cast<uint8_t>(st)][si]; | ||||||
| @ -221,6 +245,15 @@ XPUStream getStreamFromPool(const bool isHighPriority, DeviceIndex device) { | |||||||
|   return getStreamFromPool(priority, device); |   return getStreamFromPool(priority, device); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Wraps an externally created sycl::queue in an XPUStream. The queue's | ||||||
|  | // address is used directly as the StreamId; a null pointer is rejected. | ||||||
|  | XPUStream getStreamFromExternal( | ||||||
|  |     sycl::queue* ext_stream, | ||||||
|  |     DeviceIndex device_index) { | ||||||
|  |   TORCH_CHECK(ext_stream, "External stream must not be a nullptr."); | ||||||
|  | | ||||||
|  |   const auto queue_address = reinterpret_cast<int64_t>(ext_stream); | ||||||
|  |   return XPUStreamForId(device_index, queue_address); | ||||||
|  | } | ||||||
|  |  | ||||||
| // Note: The stream pools will be initialized if needed, at the first invocation | // Note: The stream pools will be initialized if needed, at the first invocation | ||||||
| // to this function. | // to this function. | ||||||
| XPUStream getCurrentXPUStream(DeviceIndex device) { | XPUStream getCurrentXPUStream(DeviceIndex device) { | ||||||
|  | |||||||
| @ -157,6 +157,16 @@ getStreamFromPool(const bool isHighPriority = false, DeviceIndex device = -1); | |||||||
| C10_XPU_API XPUStream | C10_XPU_API XPUStream | ||||||
| getStreamFromPool(const int priority, DeviceIndex device = -1); | getStreamFromPool(const int priority, DeviceIndex device = -1); | ||||||
|  |  | ||||||
|  | /** | ||||||
|  |  * Get an XPUStream from an externally allocated one. | ||||||
|  |  * | ||||||
|  |  * This is mainly for interoperability with different libraries where we | ||||||
|  |  * want to operate on a non-torch allocated stream for data exchange or similar | ||||||
|  |  * purposes. | ||||||
|  |  * | ||||||
|  |  * ext_stream must not be a nullptr. device_index is the device the returned | ||||||
|  |  * stream is associated with. | ||||||
|  |  * NOTE(review): presumably the caller keeps ownership of the queue and must | ||||||
|  |  * keep it alive while the returned stream is in use - confirm. | ||||||
|  |  */ | ||||||
|  | C10_XPU_API XPUStream | ||||||
|  | getStreamFromExternal(sycl::queue* ext_stream, DeviceIndex device_index); | ||||||
|  |  | ||||||
| /** | /** | ||||||
|  * Get the current XPU stream, for the passed XPU device, or for the current |  * Get the current XPU stream, for the passed XPU device, or for the current | ||||||
|  * device if no device index is passed. |  * device if no device index is passed. | ||||||
|  | |||||||
| @ -1,8 +1,11 @@ | |||||||
| #include <gtest/gtest.h> | #include <gtest/gtest.h> | ||||||
|  |  | ||||||
|  | #include <c10/core/DeviceGuard.h> | ||||||
| #include <c10/util/irange.h> | #include <c10/util/irange.h> | ||||||
|  | #include <c10/xpu/XPUException.h> | ||||||
| #include <c10/xpu/XPUStream.h> | #include <c10/xpu/XPUStream.h> | ||||||
| #include <c10/xpu/test/impl/XPUTest.h> | #include <c10/xpu/test/impl/XPUTest.h> | ||||||
|  |  | ||||||
| #include <optional> | #include <optional> | ||||||
|  |  | ||||||
| #include <thread> | #include <thread> | ||||||
| @ -178,3 +181,70 @@ TEST(XPUStreamTest, StreamFunction) { | |||||||
|   validateHostData(hostData, numel); |   validateHostData(hostData, numel); | ||||||
|   sycl::free(deviceData, c10::xpu::get_device_context()); |   sycl::free(deviceData, c10::xpu::get_device_context()); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Wraps a test-owned sycl::queue in an XPUStream, installs it as the current | ||||||
|  | // stream, and checks that the current stream resolves to that same queue. | ||||||
|  | TEST(XPUStreamTest, ExternalTest) { | ||||||
|  |   if (!has_xpu()) { | ||||||
|  |     return; | ||||||
|  |   } | ||||||
|  | | ||||||
|  |   c10::DeviceGuard device_guard(c10::Device(c10::DeviceType::XPU, 0)); | ||||||
|  | | ||||||
|  |   using namespace sycl::ext::oneapi::property; | ||||||
|  |   auto* ext_queue = new sycl::queue( | ||||||
|  |       c10::xpu::get_device_context(), | ||||||
|  |       c10::xpu::get_raw_device(0), | ||||||
|  |       c10::xpu::asyncHandler, | ||||||
|  |       {sycl::property::queue::in_order(), queue::priority_normal()}); | ||||||
|  | | ||||||
|  |   at::xpu::XPUStream wrapped = at::xpu::getStreamFromExternal(ext_queue, 0); | ||||||
|  |   at::xpu::setCurrentXPUStream(wrapped); | ||||||
|  | | ||||||
|  |   at::xpu::XPUStream current = at::xpu::getCurrentXPUStream(); | ||||||
|  |   ASSERT_TRUE(current == wrapped); | ||||||
|  |   ASSERT_TRUE(&(current.queue()) == ext_queue); | ||||||
|  | | ||||||
|  |   // The test owns the queue, so it releases it itself. | ||||||
|  |   delete ext_queue; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Verifies different external streams can be used for different devices at the | ||||||
|  | // same time: each device keeps its own current stream, so setting the current | ||||||
|  | // stream for device 1 must not disturb the one set for device 0. | ||||||
|  | TEST(XPUStreamTest, ExternalMultiDeviceTest) { | ||||||
|  |   if (!has_xpu()) { | ||||||
|  |     return; | ||||||
|  |   } | ||||||
|  |   // Two devices are required to compare per-device current streams. | ||||||
|  |   if (c10::xpu::device_count() < 2) | ||||||
|  |     return; | ||||||
|  |   sycl::queue* stream_0 = nullptr; | ||||||
|  |   sycl::queue* stream_1 = nullptr; | ||||||
|  | | ||||||
|  |   using namespace sycl::ext::oneapi::property; | ||||||
|  |   { | ||||||
|  |     c10::DeviceGuard device_guard(c10::Device(c10::DeviceType::XPU, 0)); | ||||||
|  |     stream_0 = new sycl::queue( | ||||||
|  |         c10::xpu::get_device_context(), | ||||||
|  |         c10::xpu::get_raw_device(0), | ||||||
|  |         c10::xpu::asyncHandler, | ||||||
|  |         {sycl::property::queue::in_order(), queue::priority_normal()}); | ||||||
|  |   } | ||||||
|  |   { | ||||||
|  |     c10::DeviceGuard device_guard(c10::Device(c10::DeviceType::XPU, 1)); | ||||||
|  |     // The device-1 queue must be stored in stream_1: assigning it to stream_0 | ||||||
|  |     // would leak the device-0 queue and pass a nullptr to | ||||||
|  |     // getStreamFromExternal below. | ||||||
|  |     stream_1 = new sycl::queue( | ||||||
|  |         c10::xpu::get_device_context(), | ||||||
|  |         c10::xpu::get_raw_device(1), | ||||||
|  |         c10::xpu::asyncHandler, | ||||||
|  |         {sycl::property::queue::in_order(), queue::priority_normal()}); | ||||||
|  |   } | ||||||
|  |   at::xpu::XPUStream myStream0 = at::xpu::getStreamFromExternal(stream_0, 0); | ||||||
|  |   at::xpu::XPUStream myStream1 = at::xpu::getStreamFromExternal(stream_1, 1); | ||||||
|  | | ||||||
|  |   at::xpu::setCurrentXPUStream(myStream0); | ||||||
|  |   ASSERT_TRUE(at::xpu::getCurrentXPUStream(0) == myStream0); | ||||||
|  |   at::xpu::setCurrentXPUStream(myStream1); | ||||||
|  |   ASSERT_TRUE(at::xpu::getCurrentXPUStream(0) == myStream0); | ||||||
|  |   ASSERT_TRUE(at::xpu::getCurrentXPUStream(1) == myStream1); | ||||||
|  | | ||||||
|  |   delete stream_0; | ||||||
|  |   delete stream_1; | ||||||
|  | } | ||||||
		Reference in New Issue
	
	Block a user