mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	[CUDA] Add experimental green context support for SM carveout (#159104)
Low-level PyTorch APIs should be usable/stable enough at this point but we might move the underlying driver API usage a bit from here... Built on top of @drisspg 's branch Pull Request resolved: https://github.com/pytorch/pytorch/pull/159104 Approved by: https://github.com/ngimel Co-authored-by: drisspg <drisspguessous@gmail.com>
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							7eb1eb4313
						
					
				
				
					commit
					3c59351c6e
				
			| @ -51,6 +51,17 @@ | ||||
|  | ||||
| #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030) | ||||
| #define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) \ | ||||
|   _(cuCtxFromGreenCtx, 12080)              \ | ||||
|   _(cuCtxGetCurrent, 12080)                \ | ||||
|   _(cuCtxPopCurrent, 12080)                \ | ||||
|   _(cuCtxPushCurrent, 12080)               \ | ||||
|   _(cuCtxSetCurrent, 12080)                \ | ||||
|   _(cuGreenCtxCreate, 12080)               \ | ||||
|   _(cuGreenCtxDestroy, 12080)              \ | ||||
|   _(cuDevSmResourceSplitByCount, 12080)    \ | ||||
|   _(cuDeviceGet, 12080)                    \ | ||||
|   _(cuDeviceGetDevResource, 12080)         \ | ||||
|   _(cuDevResourceGenerateDesc, 12080)      \ | ||||
|   _(cuMulticastAddDevice, 12030)           \ | ||||
|   _(cuMulticastBindMem, 12030)             \ | ||||
|   _(cuMulticastCreate, 12030)              \ | ||||
|  | ||||
| @ -262,6 +262,28 @@ See the docs for {class}`~torch.cuda.gds.GdsFile` for an example of how to use t | ||||
|  | ||||
| ``` | ||||
|  | ||||
| ## Green Contexts (experimental) | ||||
|  | ||||
| `torch.cuda.green_contexts` provides thin wrappers around the CUDA Green Context APIs | ||||
| to enable more general carveout of SM resources for CUDA kernels. | ||||
|  | ||||
| These APIs can be used in PyTorch builds with CUDA 12.8 or newer, and additionally require a CUDA driver that reports version 12.8 or newer at runtime. | ||||
|  | ||||
| See the docs for {class}`~torch.cuda.green_contexts.GreenContext` for an example of how to use these. | ||||
|  | ||||
| ```{eval-rst} | ||||
| .. currentmodule:: torch.cuda.green_contexts | ||||
| ``` | ||||
|  | ||||
| ```{eval-rst} | ||||
| .. autosummary:: | ||||
|     :toctree: generated | ||||
|     :nosignatures: | ||||
|  | ||||
|     GreenContext | ||||
| ``` | ||||
|  | ||||
|  | ||||
| % This module needs to be documented. Adding here in the meantime | ||||
|  | ||||
| % for tracking purposes | ||||
| @ -274,6 +296,10 @@ See the docs for {class}`~torch.cuda.gds.GdsFile` for an example of how to use t | ||||
| .. py:module:: torch.cuda.gds | ||||
| ``` | ||||
|  | ||||
| ```{eval-rst} | ||||
| .. py:module:: torch.cuda.green_contexts | ||||
| ``` | ||||
|  | ||||
| ```{eval-rst} | ||||
| .. py:module:: torch.cuda.jiterator | ||||
| ``` | ||||
|  | ||||
| @ -1,6 +1,7 @@ | ||||
| # Owner(s): ["module: linear algebra"] | ||||
|  | ||||
| import contextlib | ||||
| import time | ||||
| import unittest | ||||
| from itertools import product | ||||
| from functools import partial | ||||
| @ -846,6 +847,28 @@ class TestMatmulCuda(InductorTestCase): | ||||
|                     op(a, mismatch_batch_dim_b, out_dtype=torch.float32) | ||||
|  | ||||
|  | ||||
|     @unittest.skipIf(not _get_torch_cuda_version() >= (12, 8), "Green Context only tested on 12.8+") | ||||
|     def test_greencontext_carveout(self): | ||||
|         a = torch.randn(4096, 4096, device='cuda', dtype=torch.bfloat16) | ||||
|         ctx = torch.cuda.green_contexts.GreenContext.create(1, 0) | ||||
|         ctx.make_current() | ||||
|         torch.matmul(a, a) | ||||
|         torch.cuda.synchronize() | ||||
|         t0 = time.perf_counter() | ||||
|         partial_res = torch.matmul(a, a) | ||||
|         torch.cuda.synchronize() | ||||
|         t1 = time.perf_counter() | ||||
|         ctx.pop_current() | ||||
|         torch.matmul(a, a) | ||||
|         torch.cuda.synchronize() | ||||
|         t2 = time.perf_counter() | ||||
|         full_res = torch.matmul(a, a) | ||||
|         torch.cuda.synchronize() | ||||
|         t3 = time.perf_counter() | ||||
|         self.assertEqual(partial_res, full_res) | ||||
|         self.assertGreater(t1 - t0, t3 - t2) | ||||
|  | ||||
|  | ||||
| @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") | ||||
| @unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions") | ||||
| @unittest.skipIf(not _IS_SM8X, "mixed dtypes linear only supported on SM 8.x") | ||||
|  | ||||
| @ -123,6 +123,7 @@ class TestPublicBindings(TestCase): | ||||
|             "FutureType", | ||||
|             "Generator", | ||||
|             "GeneratorType", | ||||
|             "GreenContext", | ||||
|             "get_autocast_cpu_dtype", | ||||
|             "get_autocast_dtype", | ||||
|             "get_autocast_ipu_dtype", | ||||
|  | ||||
| @ -2777,3 +2777,17 @@ class _StaticCudaLauncher: | ||||
|         args: tuple[Any, ...], | ||||
|         stream: _int, | ||||
|     ) -> None: ... | ||||
|  | ||||
| # Defined in torch/csrc/cuda/green_context.h | ||||
class GreenContext:
    # Type stub for the native class bound under the name "GreenContext".
    # Both arguments of ``create`` are required positionally: the binding
    # declares no defaults. The native signature takes an optional device
    # index, so ``None`` is accepted and selects the current device.
    @staticmethod
    def create(
        num_sms: _int,
        device_id: _int | None,
    ) -> GreenContext: ...
    def make_current(self) -> None: ...
    def pop_current(self) -> None: ...
|  | ||||
| @ -23,6 +23,7 @@ | ||||
| #include <c10/cuda/CUDAAllocatorConfig.h> | ||||
| #include <c10/cuda/CUDACachingAllocator.h> | ||||
| #include <c10/cuda/CUDAFunctions.h> | ||||
| #include <torch/csrc/cuda/green_context.h> | ||||
| #include <ATen/cuda/CUDAGraphsUtils.cuh> | ||||
|  | ||||
| #ifdef USE_NCCL | ||||
| @ -1491,6 +1492,13 @@ static void registerCudaPluggableAllocator(PyObject* module) { | ||||
|         addStorageDeleterFns(storages_to_add_deleters_to, delta); | ||||
|       }); | ||||
| } | ||||
| static void initGreenContext(PyObject* module) { | ||||
|   auto m = py::handle(module).cast<py::module>(); | ||||
|   py::class_<GreenContext>(m, "GreenContext") | ||||
|       .def_static("create", &GreenContext::create) | ||||
|       .def("make_current", &GreenContext::makeCurrent) | ||||
|       .def("pop_current", &GreenContext::popCurrent); | ||||
| } | ||||
|  | ||||
| static void bindGetDeviceProperties(PyObject* module) { | ||||
|   // Add method to torch.cuda | ||||
| @ -2215,6 +2223,7 @@ void initModule(PyObject* module) { | ||||
|   registerCudaDeviceProperties(module); | ||||
|   registerCudaPluggableAllocator(module); | ||||
|   initCudaMethodBindings(module); | ||||
|   initGreenContext(module); | ||||
| } | ||||
|  | ||||
| } // namespace torch::cuda | ||||
|  | ||||
							
								
								
									
										213
									
								
								torch/csrc/cuda/green_context.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										213
									
								
								torch/csrc/cuda/green_context.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,213 @@ | ||||
#pragma once
#include <ATen/cuda/CUDAEvent.h>

#include <memory>
#include <optional>
#include <stdexcept>
#include <utility>
#include <vector>

#if defined(CUDA_VERSION) && !defined(USE_ROCM)
#include <c10/cuda/driver_api.h>
#include <cuda.h>
#endif
|  | ||||
| class GreenContext { | ||||
|  public: | ||||
|   GreenContext(int device_id, unsigned int num_sms) { | ||||
| #if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM) | ||||
|     int driver_version; | ||||
|     C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version)); | ||||
|     TORCH_CHECK( | ||||
|         driver_version >= 12080, "cuda driver too old to use green context!"); | ||||
|     CUcontext pctx = nullptr; | ||||
|     C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx)); | ||||
|     if (C10_UNLIKELY(!pctx)) { | ||||
|       TORCH_WARN( | ||||
|           "Attempted to create a green context but" | ||||
|           " there was no primary context! Creating a primary context..."); | ||||
|  | ||||
|       cudaFree(0); | ||||
|     } | ||||
|  | ||||
|     CUdevice device; | ||||
|     device_id_ = device_id; | ||||
|     C10_CUDA_DRIVER_CHECK( | ||||
|         c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id)); | ||||
|  | ||||
|     // Get device resources | ||||
|     CUdevResource device_resource; | ||||
|     C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_( | ||||
|         device, &device_resource, CU_DEV_RESOURCE_TYPE_SM)); | ||||
|  | ||||
|     // Split resources | ||||
|     std::vector<CUdevResource> result(1); | ||||
|     auto result_data = result.data(); | ||||
|     unsigned int nb_groups = 1; | ||||
|     CUdevResource remaining; | ||||
|  | ||||
|     C10_CUDA_DRIVER_CHECK( | ||||
|         c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_( | ||||
|             result_data, | ||||
|             &nb_groups, | ||||
|             &device_resource, | ||||
|             &remaining, | ||||
|             0, // default flags | ||||
|             num_sms)); | ||||
|  | ||||
|     TORCH_CHECK(nb_groups == 1, "Failed to create single resource group"); | ||||
|  | ||||
|     // Generate resource descriptor | ||||
|     CUdevResourceDesc desc; | ||||
|     C10_CUDA_DRIVER_CHECK( | ||||
|         c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_( | ||||
|             &desc, result_data, 1)); | ||||
|  | ||||
|     // Create green context | ||||
|     // CU_GREEN_CTX_DEFAULT_STREAM is required per docs: | ||||
|     // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html | ||||
|     C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_( | ||||
|         &green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM)); | ||||
|  | ||||
|     // Convert to regular context | ||||
|     C10_CUDA_DRIVER_CHECK( | ||||
|         c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_)); | ||||
|     TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!"); | ||||
| #else | ||||
|     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||
| #endif | ||||
|   } | ||||
|  | ||||
|   static std::unique_ptr<GreenContext> create( | ||||
|       unsigned int num_sms, | ||||
|       std::optional<unsigned int> device_id) { | ||||
| #if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM) | ||||
|     if (!device_id.has_value()) { | ||||
|       device_id = at::cuda::current_device(); | ||||
|     } | ||||
|     return std::make_unique<GreenContext>(device_id.value(), num_sms); | ||||
| #else | ||||
|     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||
| #endif | ||||
|   } | ||||
|  | ||||
|   // Delete copy constructor and assignment | ||||
|   GreenContext(const GreenContext&) = delete; | ||||
|   GreenContext& operator=(const GreenContext&) = delete; | ||||
|  | ||||
|   // Implement move operations | ||||
|   GreenContext(GreenContext&& other) noexcept { | ||||
| #if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM) | ||||
|     device_id_ = std::exchange(other.device_id_, -1); | ||||
|     green_ctx_ = std::exchange(other.green_ctx_, nullptr); | ||||
|     context_ = std::exchange(other.context_, nullptr); | ||||
|     parent_stream_ = std::exchange(other.parent_stream_, nullptr); | ||||
| #else | ||||
|     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||
| #endif | ||||
|   } | ||||
|  | ||||
|   GreenContext& operator=(GreenContext&& other) noexcept { | ||||
| #if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM) | ||||
|     if (this != &other) { | ||||
|       // Clean up current resources | ||||
|       if (green_ctx_) { | ||||
|         CUcontext current = nullptr; | ||||
|         C10_CUDA_DRIVER_CHECK( | ||||
|             c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(¤t)); | ||||
|         if (current == context_) { | ||||
|           TORCH_CHECK( | ||||
|               false, | ||||
|               "attempting to overwrite current green ctx " | ||||
|               "when it is active!"); | ||||
|         } | ||||
|         C10_CUDA_DRIVER_CHECK(cuGreenCtxDestroy(green_ctx_)); | ||||
|       } | ||||
|  | ||||
|       // Take ownership of other's resources | ||||
|       device_id_ = std::exchange(other.device_id_, -1); | ||||
|       green_ctx_ = std::exchange(other.green_ctx_, nullptr); | ||||
|       context_ = std::exchange(other.context_, nullptr); | ||||
|       parent_stream_ = std::exchange(other.parent_stream_, nullptr); | ||||
|     } | ||||
|     return *this; | ||||
| #else | ||||
|     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||
| #endif | ||||
|   } | ||||
|  | ||||
|   ~GreenContext() noexcept { | ||||
| #if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM) | ||||
|     C10_CUDA_DRIVER_CHECK( | ||||
|         c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_)); | ||||
| #else | ||||
|     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||
| #endif | ||||
|   } | ||||
|  | ||||
|   // Get the underlying CUDA context | ||||
|   CUcontext getContext() const { | ||||
| #if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM) | ||||
|     return context_; | ||||
| #else | ||||
|     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||
| #endif | ||||
|   } | ||||
|  | ||||
|   // Get the underlying green context | ||||
| #if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM) | ||||
|   CUgreenCtx getGreenContext() const { | ||||
|     return green_ctx_; | ||||
|   } | ||||
| #endif | ||||
|  | ||||
|   // Make this context current | ||||
|   void makeCurrent() { | ||||
| #if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM) | ||||
|     auto current_stream = c10::cuda::getCurrentCUDAStream(); | ||||
|     parent_stream_ = current_stream.stream(); | ||||
|  | ||||
|     at::cuda::CUDAEvent ev; | ||||
|     ev.record(current_stream); | ||||
|  | ||||
|     CUcontext current = nullptr; | ||||
|     C10_CUDA_DRIVER_CHECK( | ||||
|         c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(¤t)); | ||||
|     if (!current) { | ||||
|       C10_CUDA_DRIVER_CHECK( | ||||
|           c10::cuda::DriverAPI::get()->cuCtxSetCurrent_(context_)); | ||||
|     } else { | ||||
|       C10_CUDA_DRIVER_CHECK( | ||||
|           c10::cuda::DriverAPI::get()->cuCtxPushCurrent_(context_)); | ||||
|     } | ||||
|     // currently hardcodes the new green context to use the default stream | ||||
|     // TODO(eqy): consider creating a new stream if e.g., it allows interop | ||||
|     // with CUDA Graph captures etc. | ||||
|     auto default_stream = c10::cuda::getDefaultCUDAStream(); | ||||
|     ev.block(default_stream); | ||||
|     c10::cuda::setCurrentCUDAStream(default_stream); | ||||
| #else | ||||
|     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||
| #endif | ||||
|   } | ||||
|  | ||||
|   void popCurrent() { | ||||
| #if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM) | ||||
|     // see above note about stream being hardcoded to the default stream | ||||
|     at::cuda::CUDAEvent ev; | ||||
|     ev.record(c10::cuda::getCurrentCUDAStream()); | ||||
|     CUcontext popped; | ||||
|     C10_CUDA_DRIVER_CHECK( | ||||
|         c10::cuda::DriverAPI::get()->cuCtxPopCurrent_(&popped)); | ||||
|     TORCH_INTERNAL_ASSERT( | ||||
|         popped == context_, "expected popped context to be the current ctx"); | ||||
|     ev.block(c10::cuda::getStreamFromExternal(parent_stream_, device_id_)); | ||||
| #else | ||||
|     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||
| #endif | ||||
|   } | ||||
|  | ||||
|  private: | ||||
| #if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM) | ||||
|   int device_id_ = -1; | ||||
|   CUgreenCtx green_ctx_ = nullptr; | ||||
|   CUcontext context_ = nullptr; | ||||
|   cudaStream_t parent_stream_ = nullptr; | ||||
| #endif | ||||
| }; | ||||
| @ -35,6 +35,7 @@ from .graphs import ( | ||||
|     is_current_stream_capturing, | ||||
|     make_graphed_callables, | ||||
| ) | ||||
| from .green_contexts import GreenContext | ||||
| from .streams import Event, ExternalStream, Stream | ||||
|  | ||||
|  | ||||
| @ -1831,6 +1832,7 @@ __all__ = [ | ||||
|     "ExternalStream", | ||||
|     "Stream", | ||||
|     "StreamContext", | ||||
|     "GreenContext", | ||||
|     "amp", | ||||
|     "caching_allocator_alloc", | ||||
|     "caching_allocator_delete", | ||||
|  | ||||
							
								
								
									
										42
									
								
								torch/cuda/green_contexts.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								torch/cuda/green_contexts.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,42 @@ | ||||
| import torch | ||||
|  | ||||
|  | ||||
_GreenContext = object
SUPPORTED = False

# Use the native class as the base only when the extension module exposes it;
# otherwise fall back to ``object`` so this module still imports.
if hasattr(torch._C, "GreenContext"):
    _GreenContext = torch._C.GreenContext  # type: ignore[misc]
    SUPPORTED = True


# Python shim helps Sphinx process docstrings more reliably.
class GreenContext(_GreenContext):
    r"""Wrapper around a CUDA green context.

    .. warning::
       This API is experimental and may change in future releases.
    """

    @staticmethod
    def create(num_sms: int, device_id: int = 0) -> _GreenContext:
        r"""Create a CUDA green context.

        Arguments:
            num_sms (int): The number of SMs to use in the green context.
            device_id (int, optional): The device index of green context.
                Defaults to ``0``.

        Raises:
            RuntimeError: If PyTorch was not built with Green Context support.
        """
        if not SUPPORTED:
            raise RuntimeError("PyTorch was not built with Green Context support!")
        return _GreenContext.create(num_sms, device_id)  # type: ignore[attr-defined]

    # Note that these functions are bypassed (``create`` returns whatever the
    # native ``create`` returns, not an instance of this shim), but we define
    # them here for Sphinx documentation purposes.
    def make_current(self) -> None:
        r"""Make the green context the current context.

        Raises:
            RuntimeError: If PyTorch was not built with Green Context support.
        """
        # Without the native base class ``super()`` has no such method; raise
        # the same error as ``create`` instead of an opaque AttributeError.
        if not SUPPORTED:
            raise RuntimeError("PyTorch was not built with Green Context support!")
        return super().make_current()  # type: ignore[misc]

    def pop_current(self) -> None:
        r"""Assuming the green context is the current context, pop it from the
        context stack and restore the previous context.

        Raises:
            RuntimeError: If PyTorch was not built with Green Context support.
        """
        if not SUPPORTED:
            raise RuntimeError("PyTorch was not built with Green Context support!")
        return super().pop_current()  # type: ignore[misc]
		Reference in New Issue
	
	Block a user