Compare commits

...

32 Commits

Author SHA1 Message Date
e14de25737 try to fix docs 2025-10-01 16:42:28 +00:00
3d5d978f61 lint 2025-09-30 22:08:53 +00:00
0eac4a6bc8 dats crazy 2025-09-30 21:55:02 +00:00
a1f0fea63b try to fix imports 2025-09-30 15:29:45 +00:00
ef0b38fa29 try fix rocm build and docs build 2025-09-29 22:25:16 +00:00
e3a937fd96 fix tensorpipe 2025-09-29 20:30:40 +00:00
7d3c70f1b2 lint 2025-09-29 20:19:49 +00:00
c42576e423 fix 2025-09-29 20:19:49 +00:00
c48257924b remove raw api 2025-09-29 20:19:49 +00:00
6b05c6d31a update 2025-09-29 20:19:49 +00:00
87bf69148b fix import? 2025-09-29 20:19:49 +00:00
f43f5a0c54 fix 2025-09-29 20:19:49 +00:00
f5a58491c6 lint 2025-09-29 20:19:49 +00:00
33392559bc fix 2025-09-29 20:19:49 +00:00
ce118e335d cleanup 2025-09-29 20:19:49 +00:00
1e73836a0b cleanup 2025-09-29 20:19:49 +00:00
861173011b lint 2025-09-29 20:19:49 +00:00
90c037db00 lint 2025-09-29 20:19:49 +00:00
b275e5c076 address comments 2025-09-29 20:19:49 +00:00
a7a6a86aa3 small fixes 2025-09-29 20:19:49 +00:00
840037d310 check in test 2025-09-29 20:19:49 +00:00
c6d366f76d lint 2025-09-29 20:19:49 +00:00
0ea59bb572 cleanup 2025-09-29 20:19:49 +00:00
bfce83d3a7 wip 2025-09-29 20:19:49 +00:00
c1056832b1 wip 2025-09-29 20:19:49 +00:00
c3d7a11f85 more broken 2025-09-29 20:19:49 +00:00
70f81fd27d wip 2025-09-29 20:19:49 +00:00
256e47f038 push broken 2025-09-29 20:19:49 +00:00
6510148380 fix build 2025-09-29 20:19:49 +00:00
cca5ee297a Update
[ghstack-poisoned]
2025-09-29 20:19:49 +00:00
3c3a739e13 Update
[ghstack-poisoned]
2025-09-29 20:19:49 +00:00
4c44f56905 Update (base update)
[ghstack-poisoned]
2025-09-29 20:19:49 +00:00
8 changed files with 328 additions and 1 deletions

View File

@ -51,6 +51,17 @@
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030)
#define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) \
_(cuCtxFromGreenCtx, 12080) \
_(cuCtxGetCurrent, 12080) \
_(cuCtxPopCurrent, 12080) \
_(cuCtxPushCurrent, 12080) \
_(cuCtxSetCurrent, 12080) \
_(cuGreenCtxCreate, 12080) \
_(cuGreenCtxDestroy, 12080) \
_(cuDevSmResourceSplitByCount, 12080) \
_(cuDeviceGet, 12080) \
_(cuDeviceGetDevResource, 12080) \
_(cuDevResourceGenerateDesc, 12080) \
_(cuMulticastAddDevice, 12030) \
_(cuMulticastBindMem, 12030) \
_(cuMulticastCreate, 12030) \

View File

@ -262,6 +262,20 @@ See the docs for {class}`~torch.cuda.gds.GdsFile` for an example of how to use t
```
## Green Contexts (experimental)
`torch.cuda.green_contexts` provides thin wrappers around the CUDA Green Context APIs
to enable more general carveout of SM resources for CUDA kernels.
These APIs can be used in PyTorch with CUDA versions greater than or equal to 12.8.
See the docs for {class}`~torch.cuda.green_contexts.GreenContext` for an example of how to use these.
```{eval-rst}
.. currentmodule:: torch.cuda.green_contexts
```
% This module needs to be documented. Adding here in the meantime
% for tracking purposes
@ -299,4 +313,4 @@ See the docs for {class}`~torch.cuda.gds.GdsFile` for an example of how to use t
:hidden:
cuda.aliases.md
```
```

View File

@ -5,6 +5,7 @@ import json
import math
import re
import tempfile
import time
import unittest
from itertools import product
from functools import partial
@ -1713,6 +1714,27 @@ class TestFP8Matmul(TestCase):
self.assertNotEqual(no_carveout, carveout_66)
self.assertNotEqual(carveout_66, carveout_0)
@unittest.skipIf(not _get_torch_cuda_version() >= (12, 8), "Green Context only tested on 12.8+")
def test_greencontext_carveout(self):
    """Check that a matmul under a 1-SM green context is slower than a full one.

    The same bf16 matmul is timed twice: once while a green context created
    with ``num_sms=1`` on device 0 is current, and once after it has been
    popped. Both runs must produce equal results, and the run under the
    green context is expected to take longer.
    """
    a = torch.randn(4096, 4096, device='cuda', dtype=torch.bfloat16)
    ctx = torch.cuda.green_contexts.GreenContext.create(1, 0)
    ctx.make_current()
    try:
        # Untimed warm-up so one-time setup cost is excluded from the measurement.
        torch.matmul(a, a)
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        partial_res = torch.matmul(a, a)
        torch.cuda.synchronize()
        t1 = time.perf_counter()
    finally:
        # Always restore the previous context; otherwise a failure above would
        # leave the green context current for every test that runs afterwards.
        ctx.pop_current()
    # Same warm-up + timed pattern with the green context no longer current.
    torch.matmul(a, a)
    torch.cuda.synchronize()
    t2 = time.perf_counter()
    full_res = torch.matmul(a, a)
    torch.cuda.synchronize()
    t3 = time.perf_counter()
    self.assertEqual(partial_res, full_res)
    self.assertGreater(t1 - t0, t3 - t2)
def test_pack_uint4(self):
"""
Verify that given a tensor with high precision values [val0, val1],

View File

@ -2775,3 +2775,17 @@ class _StaticCudaLauncher:
args: tuple[Any, ...],
stream: _int,
) -> None: ...
# Defined in torch/csrc/cuda/green_context.h
class GreenContext:
    # The C++ factory takes ``std::optional<unsigned int>`` for the device, so
    # ``None`` is accepted and resolves to the current device. The binding
    # declares no default, so the argument must still be passed explicitly.
    @staticmethod
    def create(
        num_sms: _int,
        device_id: _int | None,
    ) -> GreenContext: ...
    # Makes this green context the current CUDA context.
    def make_current(
        self,
    ) -> None: ...
    # Pops this green context off the context stack.
    def pop_current(
        self,
    ) -> None: ...

View File

@ -23,6 +23,7 @@
#include <c10/cuda/CUDAAllocatorConfig.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAFunctions.h>
#include <torch/csrc/cuda/green_context.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#ifdef USE_NCCL
@ -1491,6 +1492,13 @@ static void registerCudaPluggableAllocator(PyObject* module) {
addStorageDeleterFns(storages_to_add_deleters_to, delta);
});
}
// Registers the `GreenContext` class on the given Python module: `create` is
// exposed as a static factory, `make_current` / `pop_current` as instance
// methods forwarding to the corresponding C++ members.
static void initGreenContext(PyObject* module) {
  auto py_module = py::handle(module).cast<py::module>();
  py::class_<GreenContext> green_context_cls(py_module, "GreenContext");
  green_context_cls.def_static("create", &GreenContext::create);
  green_context_cls.def("make_current", &GreenContext::makeCurrent);
  green_context_cls.def("pop_current", &GreenContext::popCurrent);
}
static void bindGetDeviceProperties(PyObject* module) {
// Add method to torch.cuda
@ -2215,6 +2223,7 @@ void initModule(PyObject* module) {
registerCudaDeviceProperties(module);
registerCudaPluggableAllocator(module);
initCudaMethodBindings(module);
initGreenContext(module);
}
} // namespace torch::cuda

View File

@ -0,0 +1,213 @@
#pragma once
#include <ATen/cuda/CUDAEvent.h>
#if defined(CUDA_VERSION) && !defined(USE_ROCM)
#include <c10/cuda/driver_api.h>
#include <cuda.h>
#include <memory>
#include <stdexcept>
#include <vector>
#endif
// Move-only owner of a CUDA green context (`CUgreenCtx`) together with the
// regular `CUcontext` derived from it.
//
// All functionality is compiled in only for CUDA >= 12.8 non-ROCm builds; in
// any other build every entry point that could produce or use an instance
// fails with "Green Context is only supported on CUDA 12.8+!".
class GreenContext {
 public:
  // Creates a green context on device `device_id` from an SM resource split
  // requested with `num_sms`.
  // Fails via TORCH_CHECK if the driver is older than 12.8, if the split does
  // not yield exactly one group, or if the conversion to a regular context
  // returns null; driver API failures are raised by C10_CUDA_DRIVER_CHECK.
  GreenContext(int device_id, unsigned int num_sms) {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
    int driver_version;
    C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
    TORCH_CHECK(
        driver_version >= 12080, "cuda driver too old to use green context!");
    CUcontext pctx = nullptr;
    C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
    if (C10_UNLIKELY(!pctx)) {
      TORCH_WARN(
          "Attempted to create a green context but"
          " there was no primary context! Creating a primary context...");
      // Runtime API call issued only to get a primary context created.
      // NOTE(review): the return code is ignored here.
      cudaFree(0);
    }

    CUdevice device;
    device_id_ = device_id;
    C10_CUDA_DRIVER_CHECK(
        c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));

    // Get device resources
    CUdevResource device_resource;
    C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
        device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));

    // Split resources: ask for a single group; whatever is left over goes to
    // `remaining`, which is not used afterwards.
    std::vector<CUdevResource> result(1);
    auto result_data = result.data();
    unsigned int nb_groups = 1;
    CUdevResource remaining;
    C10_CUDA_DRIVER_CHECK(
        c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
            result_data,
            &nb_groups,
            &device_resource,
            &remaining,
            0, // default flags
            num_sms));
    TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");

    // Generate resource descriptor
    CUdevResourceDesc desc;
    C10_CUDA_DRIVER_CHECK(
        c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
            &desc, result_data, 1));

    // Create green context
    // CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
    // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
    C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
        &green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));

    // Convert to regular context
    C10_CUDA_DRIVER_CHECK(
        c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
    TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
#else
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
  }

  // Factory used by the Python binding. When `device_id` is empty the
  // current device (at::cuda::current_device()) is used.
  static std::unique_ptr<GreenContext> create(
      unsigned int num_sms,
      std::optional<unsigned int> device_id) {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
    if (!device_id.has_value()) {
      device_id = at::cuda::current_device();
    }
    return std::make_unique<GreenContext>(device_id.value(), num_sms);
#else
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
  }

  // Delete copy constructor and assignment
  GreenContext(const GreenContext&) = delete;
  GreenContext& operator=(const GreenContext&) = delete;

  // Implement move operations
  // The moved-from object is left with null handles and device id -1.
  GreenContext(GreenContext&& other) noexcept {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
    device_id_ = std::exchange(other.device_id_, -1);
    green_ctx_ = std::exchange(other.green_ctx_, nullptr);
    context_ = std::exchange(other.context_, nullptr);
    parent_stream_ = std::exchange(other.parent_stream_, nullptr);
#else
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
  }

  // Destroys the green context currently owned by `*this` (if any), then
  // takes over `other`'s handles.
  // NOTE(review): this function is noexcept, so the checks below terminate
  // the process rather than propagate if they fire.
  GreenContext& operator=(GreenContext&& other) noexcept {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
    if (this != &other) {
      // Clean up current resources
      if (green_ctx_) {
        CUcontext current = nullptr;
        C10_CUDA_DRIVER_CHECK(
            c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&current));
        if (current == context_) {
          TORCH_CHECK(
              false,
              "attempting to overwrite current green ctx "
              "when it is active!");
        }
        // Go through the DriverAPI table like every other driver call in
        // this class instead of referencing the raw libcuda symbol.
        C10_CUDA_DRIVER_CHECK(
            c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
      }

      // Take ownership of other's resources
      device_id_ = std::exchange(other.device_id_, -1);
      green_ctx_ = std::exchange(other.green_ctx_, nullptr);
      context_ = std::exchange(other.context_, nullptr);
      parent_stream_ = std::exchange(other.parent_stream_, nullptr);
    }
    return *this;
#else
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
  }

  // Releases the owned green context, if any. Never throws: a failed destroy
  // is reported as a warning because this destructor is noexcept.
  ~GreenContext() noexcept {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
    // A moved-from object holds a null handle and owns nothing, so there is
    // nothing to hand to the driver in that case.
    if (green_ctx_) {
      const CUresult res =
          c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_);
      if (res != CUDA_SUCCESS) {
        TORCH_WARN(
            "cuGreenCtxDestroy failed with CUresult ", static_cast<int>(res));
      }
    }
#endif
  }

  // Get the underlying CUDA context
  CUcontext getContext() const {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
    return context_;
#else
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
  }

  // Get the underlying green context
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
  CUgreenCtx getGreenContext() const {
    return green_ctx_;
  }
#endif

  // Make this context current
  // Remembers the caller's current stream as the parent stream, makes the
  // green context current (set if no context is current, push otherwise) and
  // switches the current stream to the default stream, which is made to wait
  // on an event recorded on the parent stream.
  void makeCurrent() {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
    auto current_stream = c10::cuda::getCurrentCUDAStream();
    parent_stream_ = current_stream.stream();

    at::cuda::CUDAEvent ev;
    ev.record(current_stream);

    CUcontext current = nullptr;
    C10_CUDA_DRIVER_CHECK(
        c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&current));
    if (!current) {
      C10_CUDA_DRIVER_CHECK(
          c10::cuda::DriverAPI::get()->cuCtxSetCurrent_(context_));
    } else {
      C10_CUDA_DRIVER_CHECK(
          c10::cuda::DriverAPI::get()->cuCtxPushCurrent_(context_));
    }

    // currently hardcodes the new green context to use the default stream
    // TODO(eqy): consider creating a new stream if e.g., it allows interop
    // with CUDA Graph captures etc.
    auto default_stream = c10::cuda::getDefaultCUDAStream();
    ev.block(default_stream);
    c10::cuda::setCurrentCUDAStream(default_stream);
#else
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
  }

  // Pops the current context, asserting that it is this green context, and
  // makes the parent stream saved by makeCurrent() wait on an event recorded
  // on the stream that was current inside the green context.
  void popCurrent() {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
    // see above note about stream being hardcoded to the default stream
    at::cuda::CUDAEvent ev;
    ev.record(c10::cuda::getCurrentCUDAStream());

    CUcontext popped = nullptr;
    C10_CUDA_DRIVER_CHECK(
        c10::cuda::DriverAPI::get()->cuCtxPopCurrent_(&popped));
    TORCH_INTERNAL_ASSERT(
        popped == context_, "expected popped context to be the current ctx");

    ev.block(c10::cuda::getStreamFromExternal(parent_stream_, device_id_));
#else
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
  }

 private:
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
  // Device the green context was created on; -1 once moved-from.
  int device_id_ = -1;
  // Owned green context handle; nullptr once moved-from.
  CUgreenCtx green_ctx_ = nullptr;
  // Regular context obtained from green_ctx_ via cuCtxFromGreenCtx.
  CUcontext context_ = nullptr;
  // Stream that was current when makeCurrent() was last called.
  cudaStream_t parent_stream_ = nullptr;
#endif
};

View File

@ -34,6 +34,7 @@ from .graphs import (
is_current_stream_capturing,
make_graphed_callables,
)
from .green_contexts import GreenContext
from .streams import Event, ExternalStream, Stream
@ -1830,6 +1831,7 @@ __all__ = [
"ExternalStream",
"Stream",
"StreamContext",
"GreenContext",
"amp",
"caching_allocator_alloc",
"caching_allocator_delete",

View File

@ -0,0 +1,42 @@
import torch
# Probe the native extension once at import time. ``SUPPORTED`` records
# whether this build exposes the binding; when it does not, a plain ``object``
# stands in as the base class so the Python shim can still be defined.
SUPPORTED = hasattr(torch._C, "GreenContext")
_GreenContext = torch._C.GreenContext if SUPPORTED else object  # type: ignore[misc]
# Python shim helps Sphinx process docstrings more reliably.
class GreenContext(_GreenContext):
    r"""Wrapper around a CUDA green context.

    .. warning::
        This API is in beta and may change in future releases.
    """

    @staticmethod
    def create(num_sms: int, device_id: int = 0) -> _GreenContext:
        r"""Create a CUDA green context.

        Arguments:
            num_sms (int): The number of SMs to use in the green context.
            device_id (int, optional): The device index of green context.

        Raises:
            RuntimeError: if PyTorch was built without Green Context support.
        """
        if SUPPORTED:
            return _GreenContext.create(num_sms, device_id)  # type: ignore[attr-defined]
        raise RuntimeError("PyTorch was not built with Green Context support!")

    # Note that these functions are bypassed (``create`` hands back the native
    # object directly), but we define them here for Sphinx documentation
    # purposes.
    def make_current(self) -> None:
        r"""Make the green context the current context."""
        native = super()
        return native.make_current()  # type: ignore[misc]

    def pop_current(self) -> None:
        r"""Assuming the green context is the current context, pop it from the
        context stack and restore the previous context.
        """
        native = super()
        return native.pop_current()  # type: ignore[misc]