[CUDA] Add experimental green context support for SM carveout (#159104)

Low-level PyTorch APIs should be usable/stable enough at this point but we might move the underlying driver API usage a bit from here...

Built on top of @drisspg 's branch

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159104
Approved by: https://github.com/ngimel

Co-authored-by: drisspg <drisspguessous@gmail.com>
This commit is contained in:
Eddie Yan
2025-10-03 18:59:12 +00:00
committed by PyTorch MergeBot
parent 7eb1eb4313
commit 3c59351c6e
9 changed files with 342 additions and 1 deletions

View File

@ -51,6 +51,17 @@
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030)
#define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) \
_(cuCtxFromGreenCtx, 12080) \
_(cuCtxGetCurrent, 12080) \
_(cuCtxPopCurrent, 12080) \
_(cuCtxPushCurrent, 12080) \
_(cuCtxSetCurrent, 12080) \
_(cuGreenCtxCreate, 12080) \
_(cuGreenCtxDestroy, 12080) \
_(cuDevSmResourceSplitByCount, 12080) \
_(cuDeviceGet, 12080) \
_(cuDeviceGetDevResource, 12080) \
_(cuDevResourceGenerateDesc, 12080) \
_(cuMulticastAddDevice, 12030) \
_(cuMulticastBindMem, 12030) \
_(cuMulticastCreate, 12030) \

View File

@ -262,6 +262,28 @@ See the docs for {class}`~torch.cuda.gds.GdsFile` for an example of how to use t
```
## Green Contexts (experimental)
`torch.cuda.green_contexts` provides thin wrappers around the CUDA Green Context APIs
to enable more general carveout of SM resources for CUDA kernels.
These APIs can be used in PyTorch with CUDA versions greater than or equal to 12.8.
See the docs for {class}`~torch.cuda.green_contexts.GreenContext` for an example of how to use these.
```{eval-rst}
.. currentmodule:: torch.cuda.green_contexts
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
GreenContext
```
% This module needs to be documented. Adding here in the meantime
% for tracking purposes
@ -274,6 +296,10 @@ See the docs for {class}`~torch.cuda.gds.GdsFile` for an example of how to use t
.. py:module:: torch.cuda.gds
```
```{eval-rst}
.. py:module:: torch.cuda.green_contexts
```
```{eval-rst}
.. py:module:: torch.cuda.jiterator
```
@ -299,4 +325,4 @@ See the docs for {class}`~torch.cuda.gds.GdsFile` for an example of how to use t
:hidden:
cuda.aliases.md
```
```

View File

@ -1,6 +1,7 @@
# Owner(s): ["module: linear algebra"]
import contextlib
import time
import unittest
from itertools import product
from functools import partial
@ -846,6 +847,28 @@ class TestMatmulCuda(InductorTestCase):
op(a, mismatch_batch_dim_b, out_dtype=torch.float32)
@unittest.skipIf(not _get_torch_cuda_version() >= (12, 8), "Green Context only tested on 12.8+")
def test_greencontext_carveout(self):
a = torch.randn(4096, 4096, device='cuda', dtype=torch.bfloat16)
ctx = torch.cuda.green_contexts.GreenContext.create(1, 0)
ctx.make_current()
torch.matmul(a, a)
torch.cuda.synchronize()
t0 = time.perf_counter()
partial_res = torch.matmul(a, a)
torch.cuda.synchronize()
t1 = time.perf_counter()
ctx.pop_current()
torch.matmul(a, a)
torch.cuda.synchronize()
t2 = time.perf_counter()
full_res = torch.matmul(a, a)
torch.cuda.synchronize()
t3 = time.perf_counter()
self.assertEqual(partial_res, full_res)
self.assertGreater(t1 - t0, t3 - t2)
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")
@unittest.skipIf(not _IS_SM8X, "mixed dtypes linear only supported on SM 8.x")

View File

@ -123,6 +123,7 @@ class TestPublicBindings(TestCase):
"FutureType",
"Generator",
"GeneratorType",
"GreenContext",
"get_autocast_cpu_dtype",
"get_autocast_dtype",
"get_autocast_ipu_dtype",

View File

@ -2777,3 +2777,17 @@ class _StaticCudaLauncher:
args: tuple[Any, ...],
stream: _int,
) -> None: ...
# Defined in torch/csrc/cuda/green_context.h
class GreenContext:
@staticmethod
def create(
num_sms: _int,
device_id: _int,
) -> GreenContext: ...
def make_current(
self,
) -> None: ...
def pop_current(
self,
) -> None: ...

View File

@ -23,6 +23,7 @@
#include <c10/cuda/CUDAAllocatorConfig.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAFunctions.h>
#include <torch/csrc/cuda/green_context.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#ifdef USE_NCCL
@ -1491,6 +1492,13 @@ static void registerCudaPluggableAllocator(PyObject* module) {
addStorageDeleterFns(storages_to_add_deleters_to, delta);
});
}
static void initGreenContext(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
py::class_<GreenContext>(m, "GreenContext")
.def_static("create", &GreenContext::create)
.def("make_current", &GreenContext::makeCurrent)
.def("pop_current", &GreenContext::popCurrent);
}
static void bindGetDeviceProperties(PyObject* module) {
// Add method to torch.cuda
@ -2215,6 +2223,7 @@ void initModule(PyObject* module) {
registerCudaDeviceProperties(module);
registerCudaPluggableAllocator(module);
initCudaMethodBindings(module);
initGreenContext(module);
}
} // namespace torch::cuda

View File

@ -0,0 +1,213 @@
#pragma once
#include <ATen/cuda/CUDAEvent.h>
#if defined(CUDA_VERSION) && !defined(USE_ROCM)
#include <c10/cuda/driver_api.h>
#include <cuda.h>
#include <memory>
#include <stdexcept>
#include <vector>
#endif
class GreenContext {
public:
GreenContext(int device_id, unsigned int num_sms) {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
int driver_version;
C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
TORCH_CHECK(
driver_version >= 12080, "cuda driver too old to use green context!");
CUcontext pctx = nullptr;
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
if (C10_UNLIKELY(!pctx)) {
TORCH_WARN(
"Attempted to create a green context but"
" there was no primary context! Creating a primary context...");
cudaFree(0);
}
CUdevice device;
device_id_ = device_id;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));
// Get device resources
CUdevResource device_resource;
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
// Split resources
std::vector<CUdevResource> result(1);
auto result_data = result.data();
unsigned int nb_groups = 1;
CUdevResource remaining;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
result_data,
&nb_groups,
&device_resource,
&remaining,
0, // default flags
num_sms));
TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");
// Generate resource descriptor
CUdevResourceDesc desc;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
&desc, result_data, 1));
// Create green context
// CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
&green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
// Convert to regular context
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
static std::unique_ptr<GreenContext> create(
unsigned int num_sms,
std::optional<unsigned int> device_id) {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
if (!device_id.has_value()) {
device_id = at::cuda::current_device();
}
return std::make_unique<GreenContext>(device_id.value(), num_sms);
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
// Delete copy constructor and assignment
GreenContext(const GreenContext&) = delete;
GreenContext& operator=(const GreenContext&) = delete;
// Implement move operations
GreenContext(GreenContext&& other) noexcept {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
device_id_ = std::exchange(other.device_id_, -1);
green_ctx_ = std::exchange(other.green_ctx_, nullptr);
context_ = std::exchange(other.context_, nullptr);
parent_stream_ = std::exchange(other.parent_stream_, nullptr);
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
GreenContext& operator=(GreenContext&& other) noexcept {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
if (this != &other) {
// Clean up current resources
if (green_ctx_) {
CUcontext current = nullptr;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&current));
if (current == context_) {
TORCH_CHECK(
false,
"attempting to overwrite current green ctx "
"when it is active!");
}
C10_CUDA_DRIVER_CHECK(cuGreenCtxDestroy(green_ctx_));
}
// Take ownership of other's resources
device_id_ = std::exchange(other.device_id_, -1);
green_ctx_ = std::exchange(other.green_ctx_, nullptr);
context_ = std::exchange(other.context_, nullptr);
parent_stream_ = std::exchange(other.parent_stream_, nullptr);
}
return *this;
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
~GreenContext() noexcept {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
// Get the underlying CUDA context
CUcontext getContext() const {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
return context_;
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
// Get the underlying green context
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
CUgreenCtx getGreenContext() const {
return green_ctx_;
}
#endif
// Make this context current
void makeCurrent() {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
auto current_stream = c10::cuda::getCurrentCUDAStream();
parent_stream_ = current_stream.stream();
at::cuda::CUDAEvent ev;
ev.record(current_stream);
CUcontext current = nullptr;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&current));
if (!current) {
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxSetCurrent_(context_));
} else {
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxPushCurrent_(context_));
}
// currently hardcodes the new green context to use the default stream
// TODO(eqy): consider creating a new stream if e.g., it allows interop
// with CUDA Graph captures etc.
auto default_stream = c10::cuda::getDefaultCUDAStream();
ev.block(default_stream);
c10::cuda::setCurrentCUDAStream(default_stream);
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
void popCurrent() {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
// see above note about stream being hardcoded to the default stream
at::cuda::CUDAEvent ev;
ev.record(c10::cuda::getCurrentCUDAStream());
CUcontext popped;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxPopCurrent_(&popped));
TORCH_INTERNAL_ASSERT(
popped == context_, "expected popped context to be the current ctx");
ev.block(c10::cuda::getStreamFromExternal(parent_stream_, device_id_));
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
private:
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
int device_id_ = -1;
CUgreenCtx green_ctx_ = nullptr;
CUcontext context_ = nullptr;
cudaStream_t parent_stream_ = nullptr;
#endif
};

View File

@ -35,6 +35,7 @@ from .graphs import (
is_current_stream_capturing,
make_graphed_callables,
)
from .green_contexts import GreenContext
from .streams import Event, ExternalStream, Stream
@ -1831,6 +1832,7 @@ __all__ = [
"ExternalStream",
"Stream",
"StreamContext",
"GreenContext",
"amp",
"caching_allocator_alloc",
"caching_allocator_delete",

View File

@ -0,0 +1,42 @@
import torch
_GreenContext = object
SUPPORTED = False
if hasattr(torch._C, "GreenContext"):
_GreenContext = torch._C.GreenContext # type: ignore[misc]
SUPPORTED = True
# Python shim helps Sphinx process docstrings more reliably.
class GreenContext(_GreenContext):
r"""Wrapper around a CUDA green context.
.. warning::
This API is in beta and may change in future releases.
"""
@staticmethod
def create(num_sms: int, device_id: int = 0) -> _GreenContext:
r"""Create a CUDA green context.
Arguments:
num_sms (int): The number of SMs to use in the green context.
device_id (int, optional): The device index of green context.
"""
if not SUPPORTED:
raise RuntimeError("PyTorch was not built with Green Context support!")
return _GreenContext.create(num_sms, device_id) # type: ignore[attr-defined]
# Note that these functions are bypassed by we define them here
# for Sphinx documentation purposes
def make_current(self) -> None:
r"""Make the green context the current context."""
return super().make_current() # type: ignore[misc]
def pop_current(self) -> None:
r"""Assuming the green context is the current context, pop it from the
context stack and restore the previous context.
"""
return super().pop_current() # type: ignore[misc]