diff --git a/c10/cuda/driver_api.h b/c10/cuda/driver_api.h
index 8910e581a1a4..380e7939ff76 100644
--- a/c10/cuda/driver_api.h
+++ b/c10/cuda/driver_api.h
@@ -51,6 +51,17 @@ #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030)
 #define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) \
+  _(cuCtxFromGreenCtx, 12080)              \
+  _(cuCtxGetCurrent, 12080)                \
+  _(cuCtxPopCurrent, 12080)                \
+  _(cuCtxPushCurrent, 12080)               \
+  _(cuCtxSetCurrent, 12080)                \
+  _(cuGreenCtxCreate, 12080)               \
+  _(cuGreenCtxDestroy, 12080)              \
+  _(cuDevSmResourceSplitByCount, 12080)    \
+  _(cuDeviceGet, 12080)                    \
+  _(cuDeviceGetDevResource, 12080)         \
+  _(cuDevResourceGenerateDesc, 12080)      \
   _(cuMulticastAddDevice, 12030)           \
   _(cuMulticastBindMem, 12030)             \
   _(cuMulticastCreate, 12030)              \
diff --git a/docs/source/cuda.md b/docs/source/cuda.md
index 09cf443cf067..26870c3dcc3f 100644
--- a/docs/source/cuda.md
+++ b/docs/source/cuda.md
@@ -262,6 +262,28 @@ See the docs for {class}`~torch.cuda.gds.GdsFile` for an example of how to use t
 ```
 
+## Green Contexts (experimental)
+
+`torch.cuda.green_contexts` provides thin wrappers around the CUDA Green Context APIs
+to enable more general carveout of SM resources for CUDA kernels.
+
+These APIs can be used in PyTorch with CUDA versions greater than or equal to 12.8.
+
+See the docs for {class}`~torch.cuda.green_contexts.GreenContext` for an example of how to use these.
+
+```{eval-rst}
+.. currentmodule:: torch.cuda.green_contexts
+```
+
+```{eval-rst}
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+     GreenContext
+```
+
+
 % This module needs to be documented. Adding here in the meantime
 % for tracking purposes
 
@@ -274,6 +296,10 @@ See the docs for {class}`~torch.cuda.gds.GdsFile` for an example of how to use t
 .. py:module:: torch.cuda.gds
 ```
 
+```{eval-rst}
+.. py:module:: torch.cuda.green_contexts
+```
+
 ```{eval-rst}
 .. py:module:: torch.cuda.jiterator
 ```
@@ -299,4 +325,4 @@ See the docs for {class}`~torch.cuda.gds.GdsFile` for an example of how to use t
    :hidden:
 
    cuda.aliases.md
-```
\ No newline at end of file
+```
diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py
index b1f7f91a34de..fa7e761c2497 100644
--- a/test/test_matmul_cuda.py
+++ b/test/test_matmul_cuda.py
@@ -1,6 +1,7 @@
 # Owner(s): ["module: linear algebra"]
 
 import contextlib
+import time
 import unittest
 from itertools import product
 from functools import partial
@@ -846,6 +847,28 @@ class TestMatmulCuda(InductorTestCase):
             op(a, mismatch_batch_dim_b, out_dtype=torch.float32)
 
+    @unittest.skipIf(not _get_torch_cuda_version() >= (12, 8), "Green Context only tested on 12.8+")
+    def test_greencontext_carveout(self):
+        a = torch.randn(4096, 4096, device='cuda', dtype=torch.bfloat16)
+        ctx = torch.cuda.green_contexts.GreenContext.create(1, 0)
+        ctx.make_current()
+        torch.matmul(a, a)
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+        partial_res = torch.matmul(a, a)
+        torch.cuda.synchronize()
+        t1 = time.perf_counter()
+        ctx.pop_current()
+        torch.matmul(a, a)
+        torch.cuda.synchronize()
+        t2 = time.perf_counter()
+        full_res = torch.matmul(a, a)
+        torch.cuda.synchronize()
+        t3 = time.perf_counter()
+        self.assertEqual(partial_res, full_res)
+        self.assertGreater(t1 - t0, t3 - t2)
+
+
     @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
     @unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")
     @unittest.skipIf(not _IS_SM8X, "mixed dtypes linear only supported on SM 8.x")
diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py
index fa21705c76e4..dff4e9c014c7 100644
--- a/test/test_public_bindings.py
+++ b/test/test_public_bindings.py
@@ -123,6 +123,7 @@ class TestPublicBindings(TestCase):
         "FutureType",
         "Generator",
         "GeneratorType",
+        "GreenContext",
         "get_autocast_cpu_dtype",
         "get_autocast_dtype",
         "get_autocast_ipu_dtype",
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index b3c7621aaf88..29ca4af21de1 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -2777,3 +2777,17 @@ class _StaticCudaLauncher:
         args: tuple[Any, ...],
         stream: _int,
     ) -> None: ...
+
+# Defined in torch/csrc/cuda/green_context.h
+class GreenContext:
+    @staticmethod
+    def create(
+        num_sms: _int,
+        device_id: _int,
+    ) -> GreenContext: ...
+    def make_current(
+        self,
+    ) -> None: ...
+    def pop_current(
+        self,
+    ) -> None: ...
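
Note (not part of the patch): a minimal sketch of how the Python surface above is meant to be used, mirroring `test_greencontext_carveout`. It assumes a CUDA 12.8+ build; the SM count and tensor sizes are arbitrary.

```python
import torch

# Carve out a small SM partition on device 0 (8 SMs is an arbitrary choice).
ctx = torch.cuda.green_contexts.GreenContext.create(num_sms=8, device_id=0)

a = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

ctx.make_current()            # kernels launched from here see only the carved-out SMs
partial = torch.matmul(a, a)
ctx.pop_current()             # restore the previous context (full SM count)
full = torch.matmul(a, a)

torch.cuda.synchronize()
# Same numerical result; the carved-out matmul simply runs on fewer SMs.
```
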
diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp
index 84fd3d0d714e..10db3924da78 100644
--- a/torch/csrc/cuda/Module.cpp
+++ b/torch/csrc/cuda/Module.cpp
@@ -23,6 +23,7 @@
 #include
 #include
 #include
+#include <torch/csrc/cuda/green_context.h>
 #include
 
 #ifdef USE_NCCL
@@ -1491,6 +1492,13 @@ static void registerCudaPluggableAllocator(PyObject* module) {
         addStorageDeleterFns(storages_to_add_deleters_to, delta);
       });
 }
+static void initGreenContext(PyObject* module) {
+  auto m = py::handle(module).cast<py::module>();
+  py::class_<GreenContext>(m, "GreenContext")
+      .def_static("create", &GreenContext::create)
+      .def("make_current", &GreenContext::makeCurrent)
+      .def("pop_current", &GreenContext::popCurrent);
+}
 
 static void bindGetDeviceProperties(PyObject* module) {
   // Add method to torch.cuda
@@ -2215,6 +2223,7 @@ void initModule(PyObject* module) {
   registerCudaDeviceProperties(module);
   registerCudaPluggableAllocator(module);
   initCudaMethodBindings(module);
+  initGreenContext(module);
 }
 
 } // namespace torch::cuda
diff --git a/torch/csrc/cuda/green_context.h b/torch/csrc/cuda/green_context.h
new file mode 100644
index 000000000000..80a39cd3c21e
--- /dev/null
+++ b/torch/csrc/cuda/green_context.h
@@ -0,0 +1,213 @@
+#pragma once
+#include
+#if defined(CUDA_VERSION) && !defined(USE_ROCM)
+#include
+#include
+#include
+#include
+#include
+#endif
+
+class GreenContext {
+ public:
+  GreenContext(int device_id, unsigned int num_sms) {
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
+    int driver_version;
+    C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
+    TORCH_CHECK(
+        driver_version >= 12080, "cuda driver too old to use green context!");
+    CUcontext pctx = nullptr;
+    C10_CUDA_DRIVER_CHECK(
+        c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
+    if (C10_UNLIKELY(!pctx)) {
+      TORCH_WARN(
+          "Attempted to create a green context but"
+          " there was no primary context! Creating a primary context...");
+
+      cudaFree(0);
+    }
+
+    CUdevice device;
+    device_id_ = device_id;
+    C10_CUDA_DRIVER_CHECK(
+        c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));
+
+    // Get device resources
+    CUdevResource device_resource;
+    C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
+        device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
+
+    // Split resources
+    std::vector<CUdevResource> result(1);
+    auto result_data = result.data();
+    unsigned int nb_groups = 1;
+    CUdevResource remaining;
+
+    C10_CUDA_DRIVER_CHECK(
+        c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
+            result_data,
+            &nb_groups,
+            &device_resource,
+            &remaining,
+            0, // default flags
+            num_sms));
+
+    TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");
+
+    // Generate resource descriptor
+    CUdevResourceDesc desc;
+    C10_CUDA_DRIVER_CHECK(
+        c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
+            &desc, result_data, 1));
+
+    // Create green context
+    // CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
+    // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
+    C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
+        &green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
+
+    // Convert to regular context
+    C10_CUDA_DRIVER_CHECK(
+        c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
+    TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
+#else
+    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
+#endif
+  }
+
+  static std::unique_ptr<GreenContext> create(
+      unsigned int num_sms,
+      std::optional<int> device_id) {
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
+    if (!device_id.has_value()) {
+      device_id = at::cuda::current_device();
+    }
+    return std::make_unique<GreenContext>(device_id.value(), num_sms);
+#else
+    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
+#endif
+  }
+
+  // Delete copy constructor and assignment
+  GreenContext(const GreenContext&) = delete;
+  GreenContext& operator=(const GreenContext&) = delete;
+
+  // Implement move operations
+  GreenContext(GreenContext&& other) noexcept {
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
+    device_id_ = std::exchange(other.device_id_, -1);
+    green_ctx_ = std::exchange(other.green_ctx_, nullptr);
+    context_ = std::exchange(other.context_, nullptr);
+    parent_stream_ = std::exchange(other.parent_stream_, nullptr);
+#else
+    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
+#endif
+  }
+
+  GreenContext& operator=(GreenContext&& other) noexcept {
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
+    if (this != &other) {
+      // Clean up current resources
+      if (green_ctx_) {
+        CUcontext current = nullptr;
+        C10_CUDA_DRIVER_CHECK(
+            c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&current));
+        if (current == context_) {
+          TORCH_CHECK(
+              false,
+              "attempting to overwrite current green ctx "
+              "when it is active!");
+        }
+        C10_CUDA_DRIVER_CHECK(
+            c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
+      }
+
+      // Take ownership of other's resources
+      device_id_ = std::exchange(other.device_id_, -1);
+      green_ctx_ = std::exchange(other.green_ctx_, nullptr);
+      context_ = std::exchange(other.context_, nullptr);
+      parent_stream_ = std::exchange(other.parent_stream_, nullptr);
+    }
+    return *this;
+#else
+    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
+#endif
+  }
+
+  ~GreenContext() noexcept {
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
+    // green_ctx_ may be null if this object was moved from
+    if (green_ctx_) {
+      C10_CUDA_DRIVER_CHECK(
+          c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
+    }
+#else
+    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
+#endif
+  }
+
+  // Get the underlying CUDA context
+  CUcontext getContext() const {
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
+    return context_;
+#else
+    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
+#endif
+  }
+
+  // Get the underlying green context
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
+  CUgreenCtx getGreenContext() const {
+    return green_ctx_;
+  }
+#endif
+
+  // Make this context current
+  void makeCurrent() {
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
+    auto current_stream = c10::cuda::getCurrentCUDAStream();
+    parent_stream_ = current_stream.stream();
+
+    at::cuda::CUDAEvent ev;
+    ev.record(current_stream);
+
+    CUcontext current = nullptr;
+    C10_CUDA_DRIVER_CHECK(
+        c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&current));
+    if (!current) {
+      C10_CUDA_DRIVER_CHECK(
+          c10::cuda::DriverAPI::get()->cuCtxSetCurrent_(context_));
+    } else {
+      C10_CUDA_DRIVER_CHECK(
+          c10::cuda::DriverAPI::get()->cuCtxPushCurrent_(context_));
+    }
+    // currently hardcodes the new green context to use the default stream
+    // TODO(eqy): consider creating a new stream if e.g., it allows interop
+    // with CUDA Graph captures etc.
+    auto default_stream = c10::cuda::getDefaultCUDAStream();
+    ev.block(default_stream);
+    c10::cuda::setCurrentCUDAStream(default_stream);
+#else
+    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
+#endif
+  }
+
+  void popCurrent() {
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
+    // see above note about stream being hardcoded to the default stream
+    at::cuda::CUDAEvent ev;
+    ev.record(c10::cuda::getCurrentCUDAStream());
+    CUcontext popped;
+    C10_CUDA_DRIVER_CHECK(
+        c10::cuda::DriverAPI::get()->cuCtxPopCurrent_(&popped));
+    TORCH_INTERNAL_ASSERT(
+        popped == context_, "expected popped context to be the current ctx");
+    ev.block(c10::cuda::getStreamFromExternal(parent_stream_, device_id_));
+#else
+    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
+#endif
+  }
+
+ private:
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
+  int device_id_ = -1;
+  CUgreenCtx green_ctx_ = nullptr;
+  CUcontext context_ = nullptr;
+  cudaStream_t parent_stream_ = nullptr;
+#endif
+};
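
Note (not part of the patch): `makeCurrent`/`popCurrent` above hand work off between streams by recording a CUDA event on the caller's current stream and making the default stream used under the green context wait on it, with the reverse on the way out, so the transition is ordered without a host synchronize. A rough sketch of what that means from Python, assuming a CUDA 12.8+ build (sizes and SM count are arbitrary):

```python
import torch

ctx = torch.cuda.green_contexts.GreenContext.create(num_sms=16)

a = torch.randn(8192, 8192, device="cuda")
b = a @ a            # enqueued on the current ("parent") stream

ctx.make_current()   # default stream under the green context waits on the producer of `b`
c = b @ b            # runs under the SM carveout, ordered after `b`
ctx.pop_current()    # the parent stream waits on the carveout work before continuing

d = b @ c            # safe to consume `c` without an explicit synchronize
torch.cuda.synchronize()
```
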
diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py
index bf562b68f733..954d4d9ff589 100644
--- a/torch/cuda/__init__.py
+++ b/torch/cuda/__init__.py
@@ -35,6 +35,7 @@ from .graphs import (
     is_current_stream_capturing,
     make_graphed_callables,
 )
+from .green_contexts import GreenContext
 from .streams import Event, ExternalStream, Stream
 
 
@@ -1831,6 +1832,7 @@ __all__ = [
     "ExternalStream",
     "Stream",
     "StreamContext",
+    "GreenContext",
     "amp",
     "caching_allocator_alloc",
     "caching_allocator_delete",
diff --git a/torch/cuda/green_contexts.py b/torch/cuda/green_contexts.py
new file mode 100644
index 000000000000..743dd323b5a2
--- /dev/null
+++ b/torch/cuda/green_contexts.py
@@ -0,0 +1,42 @@
+import torch
+
+
+_GreenContext = object
+SUPPORTED = False
+
+if hasattr(torch._C, "GreenContext"):
+    _GreenContext = torch._C.GreenContext  # type: ignore[misc]
+    SUPPORTED = True
+
+
+# Python shim helps Sphinx process docstrings more reliably.
+class GreenContext(_GreenContext):
+    r"""Wrapper around a CUDA green context.
+
+    .. warning::
+        This API is in beta and may change in future releases.
+    """
+
+    @staticmethod
+    def create(num_sms: int, device_id: int = 0) -> _GreenContext:
+        r"""Create a CUDA green context.
+
+        Arguments:
+            num_sms (int): The number of SMs to use in the green context.
+            device_id (int, optional): The device index of the green context.
+        """
+        if not SUPPORTED:
+            raise RuntimeError("PyTorch was not built with Green Context support!")
+        return _GreenContext.create(num_sms, device_id)  # type: ignore[attr-defined]
+
+    # Note that these functions are bypassed, but we define them here
+    # for Sphinx documentation purposes.
+    def make_current(self) -> None:
+        r"""Make the green context the current context."""
+        return super().make_current()  # type: ignore[misc]
+
+    def pop_current(self) -> None:
+        r"""Assuming the green context is the current context, pop it from the
+        context stack and restore the previous context.
+        """
+        return super().pop_current()  # type: ignore[misc]
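
One last usage note, not part of the patch: `green_contexts.SUPPORTED` only reports whether the `torch._C.GreenContext` binding exists; on a build or driver without CUDA 12.8 support, `create` can still raise (the C++ checks surface as `RuntimeError` in Python). A defensive sketch; the helper name and fallback are illustrative:

```python
import torch
from torch.cuda import green_contexts


def try_create_green_context(num_sms: int = 8):
    """Return a GreenContext, or None when green contexts are unavailable."""
    if not green_contexts.SUPPORTED:
        return None
    try:
        return green_contexts.GreenContext.create(num_sms)
    except RuntimeError:
        # Binding exists, but the CUDA toolkit/driver is older than 12.8.
        return None


ctx = try_create_green_context()
if ctx is not None:
    ctx.make_current()
    # ... launch SM-limited work here ...
    ctx.pop_current()
```
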